tensorflow实现quantization-aware training(伪量化,fake quantization)

前面一篇文章讲模型优化的时候有讲到量化模型,但那只是量化权重,在实际计算的时候还是会反量化回去,用float32位计算,没有进行实际意义上的定点运算。今天讲的这个方式是可以部署在移动端进行定点运算的,乘现在网上关于这方面资料很少,赶紧写一篇,求赞呀~~~

源代码位置:tensorflow/contrib/quantize/
github参考:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/quantize
tensorflow实例参考:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands

为啥叫伪量化?

因为它只是通过在训练时向某些能识别的操作中加入fake_quantization_node,用于统计该节点的最大值,最小值。这里统计的最大值,最小值用于后面用toco工具完全量化操作,从而减小量化操作带来的精度损失。

注意: 某些网络的某些特殊操作目前还不支持自动向图中加入fake_quantization_node统计最大,最小值,需要自己手动加入节点统计,统计得不准会带来精度大大地下降,慎用,(如果有些节点在用toco转换的过程需要用到最大值最小值,而模型在训练过程中又没有插入fake_quantization_node自动统计,它会提示你需要指定默认的最大值,最小值。)。

具体步骤

第一步: 在train.py中的loss之后,train_op之前加入tf.contrib.quantize.create_training_graph(input_graph=tf.get_default_graph(), quant_delay=20000)
训练模型,保存ckpt文件。
注意,这里的quant_delay是训练迭代多少次后,网络开始做量化统计最大值,最小值,并用8bit做反向传播更新梯度。
如果你之前有训练好一个完整的模型,可直接加载这个模型进来做微调,这时可设置quant_delay=0,在定义saver的时候,
用saver = tf.train.Saver(tf.global_variables())或者saver=tf.train.Saver()较为保险,不然后面在freeze.py中加载模型进去会说有些节点没有权重初始化。(意思是说图中的有些节点,ckpt中没有保存参数,我就碰到过这种情况)

第二步: 在freeze.py中,构建你的inference_graph,加入tf.contrib.quantize.create_eval_graph(input_graph=tf.get_default_graph()),restore之前训练保存的ckpt文件,冻结生成pb文件。

第三步: 把pb文件转换成tflite文件,我这里是调用的tensorflow提供的python API,代码如下:

import tensorflow as tf
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
graph_def_file = "train_results/eval_model.pb"
input_arrays = ["input"]
output_arrays = ["softmax"]
converter = tf.contrib.lite.TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays, output_arrays)
converter.inference_type = tf.contrib.lite.constants.QUANTIZED_UINT8
converter.quantized_input_stats = {input_arrays[0]: (73.0, 10.00667)}    # mean, std_dev,需要自己从训练集(增强后,输入网络之前的)统计出来
tflite_model = converter.convert()
open("train_results/freeze_models/converted_model.tflite", "wb").write(tflite_model)

在这里插入图片描述
最后能生成converted_model.tflite文件,大小约为转换之前的eval_model.pb的1/4左右。

第四步 测试converted_model.tflite

import numpy as np
import tensorflow as tf
import scipy
import os
import cv2

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Load TFLite model and allocate tensors.
interpreter = tf.contrib.lite.Interpreter(model_path="train_results/freeze_models/converted_model.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details)
print(output_details)

image_origin = scipy.misc.imread("src/151105230861_0_76.536.jpg", mode='RGB')   
image = tf.image.per_image_standardization(image_origin)
image_q = (image + 7.288) * 10.006671114      #这些参数通过打印input_details可以看到
image_ = np.array([image_q.astype('uint8')])

print(image_.shape)
print(type(image_))

interpreter.set_tensor(input_details[0]['index'], image_)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
print(output_data.shape, type(output_data))

def f(x):
    return 0.0078125*(x-128)        # 这些参数也是在output_details里可以看到

output_data_new = map(f, output_data[0])    #转换后output_data_new为浮点数,你可以和没量化的模型输出对比一下相似度

写在后面的话

可通过netron查看模型文件的结构。
我的eval_model.pb模型结构截图
从图中可以看到有很多FakeQuantWithMinMaxVars节点,并还有相应的最大值,最小值,那都是模型在训练时,自动统计的。
我的tflite模型结构截图:
从图中可以看到输入输出节点都有量化前后的转换参数及浮点数的范围。
最后补充
我之前用pb转tflite是用toco工具的,但始终报错,在这里折腾了几周,我的是tensorflow1.12版本的,报错信息如下:
这个错误信息应该是说不支持pooling操作???
最后偶然看到tensorflow社区里的帖子,居然还有python的API接口,遂用api试一下转tflite,于是成功了,但测试的精度不咋地,那是因为我的均值和方差没有统计。后面怎么统计的呢,我是在训练过程中的图片input节点那里插入了两个节点:

image_max = tf.reduce_max(image_batch, name='image_max')
image_min = tf.reduce_min(image_batch, name='image_min')

用于统计图片的最大值和最小值,然后mean=255min/(min-max),std_dev=255/max-min

踩过的坑:
训练时,用slim写的网络,is_train不要用placeholder,用了会导致某些节点不会插入fake_quantized_node,slim.fully_connected好像不会自动插入fake_quantized_node,或者说slim.fully_connected加了batchnorm不会自动插入fake_quantized_node统计最大值最小值。

参考
tensorflowLite的量化使用问题,帖子很长,慢慢看吧

  • 13
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 48
    评论
评论 48
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值