如何插入8bit量化节点(tensorflow)

目录

tf流图graph基础知识

默认图

创建显式图

创建多个图

调用tf伪量化接口

插入kernel、层间量化节点


tf流图graph基础知识

默认图

import tensorflow as tf
import numpy as np

a = tf.constant(123)
print(a.graph)
print(tf.get_default_graph())

输出:

<tensorflow.python.framework.ops.Graph object at 0x7f32df766668>
<tensorflow.python.framework.ops.Graph object at 0x7f32df766668>

当tensorflow库被加载时,即使用户没有显示地创建一个图,他也会自动创建一个图对象,并将其作为默认的额数据流图

创建显式图

import tensorflow as tf
import numpy as np

g= tf.Graph() #创建了一个图对象g
with g.as_default(): #返回一个上下文管理器,使得当前图对象称为当前默认图对象
    a = tf.constant(123)
    print(a.graph) #a.graph:获得a所在的图
    print(tf.get_default_graph()) #get_default_graph:获取当前图对象的句柄

输出:

<tensorflow.python.framework.ops.Graph object at 0x7f71a7a2b710>
<tensorflow.python.framework.ops.Graph object at 0x7f71a7a2b710>

创建多个图

 

import tensorflow as tf 
import numpy as np

g1 = tf.Graph()
g2 = tf.Graph()

with g1.as_default():
    a = tf.constant(123)
    print(a.graph)
    print(tf.get_default_graph())
    
with g2.as_default():
    b = tf.multiply(2, 3)
    print(b.graph)
    print(tf.get_default_graph())

 输出:

<tensorflow.python.framework.ops.Graph object at 0x7f570354fcc0>
<tensorflow.python.framework.ops.Graph object at 0x7f570354fcc0>
<tensorflow.python.framework.ops.Graph object at 0x7f5684813860>
<tensorflow.python.framework.ops.Graph object at 0x7f5684813860>

 

调用tf伪量化接口

class Quantizationint8(object):
    #初始化量化参数
    #n为量化bit数8或16,d是实际小数点位数
    def __init__(self, n, d):
        d = float(d)
        self._quant_min = -(2 ** (n - 1) - 1) * (2 ** d)
        self._quant_max = (2 ** (n - 1) - 1) * (2 ** d)
        self._num_bits = n
        self._narrow = True
    
    #利用tf接口函数计算伪量化值,并返回
    def __call__(self, inputs):
        return array_ops.fake_quant_with_min_max_vars(inputs, self._quant_min, self._quant_max,
                                                      num_bits=self._num_bits, narrow_range=self._narrow)

tf接口释义

可参考:https://blog.csdn.net/weixin_36670529/article/details/100560469: 

 

插入kernel、层间量化节点

#g为静态图句柄,producer为全精度浮点数,quant_tensor为上面伪量化值节点的输出
def insert_slim_quant_op(g, producer, name, quant_tensor):
    Tar_op= []
    for (op_name, op) in g._nodes_by_name.items():
        if producer in op.inputs._inputs and name in op_name:
            Tar_op.append(op)
    assert len(Tar_op) != 0, "do not have the Tar_op of node {}\n".format(producer.name)
    with variable_scope.variable_scope(name + '/SlimQuant'):
        graph_editor.reroute_ts([quant_tensor], [producer], can_modify=Tar_op)

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值