android object转实体类_Tensorflow模型量化4 --pb转tflite(uint8量化)小结

4033febc3b937919fcc63de0a4d1bce5.png

Tensorflow模型量化4 --pb转tflite小结(uint8量化)

  1. 实验环境:tensorflow-gpu1.15+cuda10.0

模型的fp16量化和int8量化我之前有写,参考:

龟龟:Tensorflow模型量化实践2--量化自己训练的模型​zhuanlan.zhihu.com
ad2a7b4b401b18c304ce1387ffa14fcc.png

这次发现uint8量化时有参数设置,所以准备是从头再梳理一遍

2.参与量化的模型:

训练tensorflow-object-detection API 得到的ssdlite_mobilenet _v2模型,导出为frozen_inference_graph.pb

3.获取输入输出节点

进行frozen_inference_graph.pb模型解析,得到输入输出节点信息

代码入下:

"""
code by zzg
"""
import tensorflow as tf
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
config = tf.ConfigProto() 
config.gpu_options.allow_growth = True 
 
with tf.Session() as sess:
    with open('frozen_inference_graph_resnet.pb','rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
 
        tf.import_graph_def(graph_def, name='')
        tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
        for tensor_name in tensor_name_list:
             print(tensor_name,'n')

之后找到输入节点在预处理之后入下所示:

460b2c68074041f6564c8562f6457058.png

找到输出节点在后处理之前,如下图所示:

227149b3e8ae5ff2cae6fd4ad346b9e1.png

4.量化(pb->tflite)

4.1方法一:利用TFLiteConverter

'''

4.2方法二:利用TOCO

toco  --graph_def_file 
./frozen_inference_graph.pb 
--output_file test.tflite 
--input_format=TENSORFLOW_GRAPHDEF 
--output_format=TFLITE 
--inference_type=QUANTIZED_UINT8 
--input_shape='1,300,300,3' --input_array='FeatureExtractor/MobilenetV2/MobilenetV2/input' --output_array='concat,concat_1' 
--std_dev_value 127.5 
--mean_value 127.5
--default_ranges_min 0 
--default_ranges_max 255

补充重点:uint8量化时的参数设置

01.由于是进行uint8量化,所以输出范围为[0,255]

即default_ranges_min =0,default_ranges_max=255

02.std_dev_value和mean_value参数

参考:https://www.cnblogs.com/sdu20112013/p/11960552.html

结论:
训练时模型的输入tensor的值在不同范围时,对应的mean_values,std_dev_values分别如下:

  • range (0,255) then mean = 0, std_dev = 1
  • range (-1,1) then mean = 127.5, std_dev = 127.5
  • range (0,1) then mean = 0, std_dev = 255

我查看了我的输入tensor范围是[-1,1], 所以我设置参数为 mean = 127.5, std_dev = 127.5

a2f5a00d24335a714fc1efa71ad96b28.png

5.tflite测试

在转换完成后,进行tflie解析测试,验证最后转换成功。

代码入下:

'''

最后显示如下:

c82931f541df9bcd7e86bfaf1b255bdf.png

补充:获取输入输出节点的话采用神经网络模型可视化工具Netron更加方便直观

参考:

模型结构可视化神器--Netron(支持tf, caffe, keras,mxnet等多种框架)​blog.csdn.net
9e1a792ba8183d97d83dab4e9ce06491.png
轻量好用的神经网络模型可视化工具netron_网络_Mingyong_Zhuang的技术博客-CSDN博客​blog.csdn.net
0557a92bb8abf1b9a9693d95d8e997cc.png

安装比较简单:windows直接安装.exe文件,linux直接 pip install netron即可

dfcd0b1e7136e7be6327e1fd74be9388.png
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值