2019第一篇,先祝大家新年快乐鸭~
本文使用的是Tensorflow Lite中自带的量化工具包,Github上官方代码,使用手册,我看到是18年12月才更新的工具包,import方式改变了,直接从tf.contrib.quantize中import,做的人还比较少,本文想先评估一下它的量化效果,也记录一下使用方法,因为其实官方没有给很多的demo指导。
TODO
- 在模型中加入量化工具
- 尝试在已训练好的模型加载量化
- 继续测评
打开一个训练好的graph(测试)
因为我之前做keras比较多,所以其实这里上手还是磨合了一番的。
准备环境:
python3.5
tensorflow-gpu==1.12.0
加上一个已经训练好,freeze过的graph
直接上代码,我说也说不清,参考一下这个
这里有一个很纠结的点就是这样打开.pb文件就是GraphDef而不是tf.Graph()文件,如果直接用量化函数打开graph_def就会报错 AttributeError: ‘GraphDef’ object has no attribute 'get_all_collection_keys’
import tensorflow as tf
from tensorflow.contrib.quantize import *
with tf.Session() as sess:
with gfile.FastGFile('./VGG16_freeze.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='vgg')
graph = sess.graph
out = create_training_graph(graph, 100)
但是呢,这里如果只按照我这个代码来其实还是不行的,报错如下
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
Traceback (most recent call last):
File "/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/quantize/python/quantize.py", line 200, in QuantizeOpWithWeights
input_idx = next(i for i, v in enumerate(op.inputs)
StopIteration
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/quantize/python/quantize_graph.py", line 112, in create_training_graph
device_name_or_function=device_name_or_function)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/quantize/python/quantize_graph.py", line 68, in _create_graph
quantize.Quantize(g, is_training=is_training)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/quantize/python/quantize.py", line 101, in Quantize
context.QuantizeOpWithWeights(op, folded=False)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/quantize/python/quantize.py", line 204, in QuantizeOpWithWeights
raise ValueError('No inputs to quantize for op: %s' % op)
ValueError: No inputs to quantize for op: name: "vgg/block1_conv1/convolution"
这里很明显就是说只加载图是不行的,还要定义一下输入的节点,所以整个的代码应该要加上一下对应的tf.placeholder()
说明:
这里tf版本之间区别比较大,我之前用的是1.6.0,函数的定义和最新版本的完全不一样,不过也可以做,1.6.0版本中没有quant_delay参数&