[TensorFlow Series][Part 3] Freezing a Model File and Running Inference
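This article is a very concise example that gets you up and running quickly.
Key points:
- Get the node (variable) names stored in a ckpt model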
import os
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join('./ade20k', "model.ckpt-27150")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)
    # print(reader.get_tensor(key))  # the corresponding value
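pywrap_tensorflow is an internal module, and in newer TensorFlow releases NewCheckpointReader may no longer be exposed there. The public tf.train.list_variables returns the same (name, shape) pairs. The sketch below assumes TensorFlow 2.x with the tf.compat.v1 API; the variable names dense/kernel and dense/bias are hypothetical, and the checkpoint is written to a temp directory only so the snippet runs on its own.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1


def list_checkpoint_variables(ckpt_path):
    """Return [(variable_name, shape), ...] stored in a checkpoint, sorted by name."""
    return tf.train.list_variables(ckpt_path)


# Write a tiny throwaway checkpoint (hypothetical variable names) so the sketch is self-contained.
ckpt_dir = tempfile.mkdtemp()
graph = tf.Graph()
with graph.as_default():
    with tf1.variable_scope("dense"):
        tf1.get_variable("kernel", shape=[3, 2])
        tf1.get_variable("bias", shape=[2])
    saver = tf1.train.Saver()
    with tf1.Session(graph=graph) as sess:
        sess.run(tf1.global_variables_initializer())
        # save() returns the checkpoint prefix, e.g. ".../model.ckpt-27150"
        ckpt_path = saver.save(sess, os.path.join(ckpt_dir, "model.ckpt"), global_step=27150)

variables = list_checkpoint_variables(ckpt_path)
for name, shape in variables:
    print("tensor_name:", name, shape)
```

Pass the checkpoint prefix (model.ckpt-27150), not the .index or .data file; a directory containing a "checkpoint" file also works and resolves to the latest checkpoint.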
- Get the node names in a pb model
import tensorflow as tf
import os
model_dir = './'
model_name = 'model.pb'
def create_graph():
    # Read the serialized GraphDef and import it into the default graph
    with tf.gfile.FastGFile(os.path.join(
            model_dir, model_name), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
create_graph()
# These are node (op) names; the matching tensor names add an output index such as ":0"
node_name_list = [node.name for node in tf.get_default_graph().as_graph_def().node]
for node_name in node_name_list:
    print(node_name)
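A pb file is just a serialized GraphDef protobuf, so the node names can also be read straight from graph_def.node without importing anything into the default graph. A minimal sketch, assuming TensorFlow 2.x with tf.compat.v1; the toy graph and its node names (input, output) are hypothetical.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1


def list_pb_node_names(pb_path):
    """Parse a binary GraphDef (.pb) and return its node names in file order."""
    graph_def = tf1.GraphDef()
    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    return [node.name for node in graph_def.node]


# Build and save a toy graph (hypothetical node names) so the sketch runs on its own.
graph = tf.Graph()
with graph.as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 2], name="input")
    tf.identity(x * 2.0, name="output")

pb_dir = tempfile.mkdtemp()
tf.io.write_graph(graph.as_graph_def(), pb_dir, "model.pb", as_text=False)
pb_path = os.path.join(pb_dir, "model.pb")

node_names = list_pb_node_names(pb_path)
for name in node_names:
    print(name)
```

Skipping tf.import_graph_def also avoids polluting the default graph when you only want to inspect a file.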
- Convert a ckpt model to a pb model
import tensorflow as tf
saver = tf.train.import_meta_graph("./ade20k/model.ckpt-27150.meta", clear_devices=True)
# [Pay attention!] Fill in the names of your output nodes here (op names, without ":0")
output_nodes = ["xxx"]
with tf.Session(graph=tf.get_default_graph()) as sess:
    input_graph_def = sess.graph.as_graph_def()
    saver.restore(sess, "./ade20k/model.ckpt-27150")
    # Replace every variable reachable from output_nodes with a constant holding its restored value
    output_graph_def = tf.graph_util.convert_variables_to_constants(sess,
                                                                    input_graph_def,
                                                                    output_nodes)
    with open("frozen_model.pb", "wb") as f:
        f.write(output_graph_def.SerializeToString())
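The same ckpt to pb conversion as a self-contained round trip: save a checkpoint (which also writes the .meta file), rebuild the graph from the .meta, restore, and fold the variables into constants. This is a sketch assuming TensorFlow 2.x with tf.compat.v1; the toy linear model and the names inputs, w and predictions are hypothetical stand-ins for a real model.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1

work_dir = tempfile.mkdtemp()

# 1) Training side: a toy linear model saved as a checkpoint (hypothetical names).
train_graph = tf.Graph()
with train_graph.as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 2], name="inputs")
    w = tf1.get_variable("w", initializer=tf.constant([[1.0], [2.0]]))
    tf.identity(tf.matmul(x, w), name="predictions")
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        # Saver.save also writes model.ckpt-1.meta next to the checkpoint.
        ckpt_path = tf1.train.Saver().save(
            sess, os.path.join(work_dir, "model.ckpt"), global_step=1)

# 2) Freezing side: rebuild the graph from .meta, restore weights, turn variables into constants.
output_nodes = ["predictions"]  # op names, without the ":0" suffix
with tf.Graph().as_default() as inference_graph:
    saver = tf1.train.import_meta_graph(ckpt_path + ".meta", clear_devices=True)
    with tf1.Session(graph=inference_graph) as sess:
        saver.restore(sess, ckpt_path)
        frozen_graph_def = tf1.graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), output_nodes)

frozen_path = os.path.join(work_dir, "frozen_model.pb")
with tf.io.gfile.GFile(frozen_path, "wb") as f:
    f.write(frozen_graph_def.SerializeToString())

# Map of node name -> op type in the frozen graph; no variable ops should remain.
frozen_ops = {node.name: node.op for node in frozen_graph_def.node}
print(frozen_ops)
```

convert_variables_to_constants also prunes everything not needed to compute output_nodes, which is why getting the output node names right matters.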
Key points:
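import tensorflow as tf
def load_graph(frozen_graph_filename):
    # Parse the frozen pb and import it; name="prefix" prepends "prefix/" to every op name
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="prefix")
    return graph
graph = load_graph("frozen_model.pb")
for op in graph.get_operations():
    print(op.name, op.values())
# prefix/Placeholder/inputs_placeholder
# ...
# prefix/Accuracy/predictions
# The ops we need are prefix/Placeholder/inputs_placeholder and prefix/Accuracy/predictions.
# To predict, we need the tensors to feed and fetch, so we need their names.
# Note: prefix/Placeholder/inputs_placeholder is only the op name;
# prefix/Placeholder/inputs_placeholder:0 (output 0 of that op) is the tensor name.
x = graph.get_tensor_by_name('prefix/Placeholder/inputs_placeholder:0')
y = graph.get_tensor_by_name('prefix/Accuracy/predictions:0')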
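Putting the pieces together, inference on a frozen pb is: parse the GraphDef, import it under a prefix, look up tensors by "op_name:0", and run a session. A sketch assuming TensorFlow 2.x with tf.compat.v1; the toy graph (already "frozen" because it has no variables) and its names inputs and predictions are hypothetical.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1


def load_graph(pb_path, prefix="prefix"):
    """Import a frozen GraphDef into a fresh graph; every op name gets "<prefix>/"."""
    graph_def = tf1.GraphDef()
    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        tf1.import_graph_def(graph_def, name=prefix)
    return graph


# A toy graph without variables is already "frozen" (hypothetical node names).
toy = tf.Graph()
with toy.as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 2], name="inputs")
    tf.identity(tf.matmul(x, tf.constant([[1.0], [2.0]])), name="predictions")
pb_path = os.path.join(tempfile.mkdtemp(), "frozen_model.pb")
with tf.io.gfile.GFile(pb_path, "wb") as f:
    f.write(toy.as_graph_def().SerializeToString())

graph = load_graph(pb_path)
# Op name "prefix/inputs" -> tensor name "prefix/inputs:0" (output 0 of that op).
x_in = graph.get_tensor_by_name("prefix/inputs:0")
y_out = graph.get_tensor_by_name("prefix/predictions:0")
with tf1.Session(graph=graph) as sess:
    result = sess.run(y_out, feed_dict={x_in: [[1.0, 1.0], [0.5, 0.0]]})
print(result)
```

Passing an op name (without ":0") to get_tensor_by_name raises a ValueError, which is the most common mistake when moving from the printed op list to actual feeding.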
Key points:
- Freeze the model with export_inference_graph.py and freeze_graph.py
- Freeze the model using the bazel tool
- Convert a pb file to a tflite file
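For the pb to tflite step, the TF1-style converter takes the frozen file plus input/output array names and, for placeholders with a None dimension, fixed input shapes. This is a sketch assuming TensorFlow 2.x, where the frozen-graph entry point lives under tf.compat.v1.lite; the toy graph and the names inputs and predictions are hypothetical, and the result is checked by running it in the TFLite interpreter.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1

# Toy frozen graph (hypothetical node names); a real model would use its own frozen .pb.
toy = tf.Graph()
with toy.as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 2], name="inputs")
    tf.identity(tf.matmul(x, tf.constant([[1.0], [2.0]])), name="predictions")
pb_path = os.path.join(tempfile.mkdtemp(), "frozen_model.pb")
with tf.io.gfile.GFile(pb_path, "wb") as f:
    f.write(toy.as_graph_def().SerializeToString())

# TF1-style converter: array names plus a fixed shape for the None batch dimension.
converter = tf1.lite.TFLiteConverter.from_frozen_graph(
    pb_path,
    input_arrays=["inputs"],
    output_arrays=["predictions"],
    input_shapes={"inputs": [1, 2]})
tflite_model = converter.convert()  # bytes of the .tflite flatbuffer
with open(os.path.join(os.path.dirname(pb_path), "model.tflite"), "wb") as f:
    f.write(tflite_model)

# Sanity check: run the converted model with the TFLite interpreter.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]
interpreter.set_tensor(inp["index"], np.array([[1.0, 1.0]], dtype=np.float32))
interpreter.invoke()
tflite_result = interpreter.get_tensor(out["index"])
print(tflite_result)
```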
TensorFlow: How to freeze a model and serve it with a python API
The article above is actually a summary of this one.
Saving, Freezing, Optimizing for inference, Restoring of tensorflow models
At first glance this article seems to include some unnecessary steps, but I have not read it closely yet.
Key points:
- Freeze the model with tf.train.write_graph and freeze_graph.freeze_graph
- Optimize the frozen graph with optimize_for_inference_lib.optimize_for_inference, which is used for:
  - Removing operations used only for training, like checkpoint saving.
  - Stripping out parts of the graph that are never reached.
  - Removing debug operations like CheckNumerics.
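A small sketch of optimize_for_inference on a toy frozen graph, showing two of its effects: a branch never reached from the output is stripped, and a pass-through Identity node is spliced out. It assumes TensorFlow 2.x, where the module is still importable as tensorflow.python.tools.optimize_for_inference_lib (an internal path that may change); the node names inputs, debug_identity, unused_summary and predictions are hypothetical.

```python
import tensorflow as tf
from tensorflow.python.tools import optimize_for_inference_lib

tf1 = tf.compat.v1

# Toy frozen graph with an unreachable branch and a removable Identity (hypothetical names).
graph = tf.Graph()
with graph.as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 2], name="inputs")
    logits = tf.matmul(x, tf.constant([[1.0], [2.0]]))
    hidden = tf.identity(logits, name="debug_identity")   # pass-through node
    tf.nn.relu(hidden, name="predictions")
    tf.reduce_sum(x, name="unused_summary")               # never reached from predictions

optimized_def = optimize_for_inference_lib.optimize_for_inference(
    graph.as_graph_def(),
    ["inputs"],                     # input node names
    ["predictions"],                # output node names
    tf.float32.as_datatype_enum)    # dtype for the rebuilt input placeholder

names_before = [node.name for node in graph.as_graph_def().node]
names_after = [node.name for node in optimized_def.node]
print("before:", names_before)
print("after: ", names_after)

# The optimized graph still computes the same result.
with tf.Graph().as_default() as g2:
    tf1.import_graph_def(optimized_def, name="")
    with tf1.Session(graph=g2) as sess:
        out = sess.run("predictions:0", feed_dict={"inputs:0": [[1.0, 1.0]]})
print(out)
```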
tf.contrib.layers.flatten or tf.reshape
This one also uses tf.train.write_graph and freeze_graph.freeze_graph to freeze the model, but it is not as good as the article above.
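Viewing the node information of TensorFlow model files
Viewing the node information of TensorFlow pb model files
Half of the first article is taken from the second one.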
Update [01/07/2019]
Another way to freeze a model: use the official script tensorflow/python/tools/freeze_graph.py
Steps
- Export the graph definition and restore the checkpoint model:
""" This script is a mixture of
https://github.com/YunYang1994/tensorflow-yolov3/blob/master/freeze_graph.py
and
https://gist.github.com/domluna/ed477cb5698c787f29c7d56fba381fed
(redirected from https://github.com/tensorflow/tensorflow/issues/10663)
"""
import os
import tensorflow as tf
from core.yolov3 import YOLOV3
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str()  # empty string hides all GPUs, so this runs on the CPU
pb_file = "./yolov3_coco.pb"
ckpt_file = "./checkpoint/yolov3_coco_demo.ckpt"
# The same names are passed to freeze_graph.py via --output_node_names
output_node_names = ["input/input_data", "pred_sbbox/concat_2", "pred_mbbox/concat_2", "pred_lbbox/concat_2"]
with tf.name_scope('input'):
    input_data = tf.placeholder(dtype=tf.float32, name='input_data')
model = YOLOV3(input_data, trainable=False)
print(model.conv_sbbox, model.conv_mbbox, model.conv_lbbox)
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
saver = tf.train.Saver()
saver.restore(sess, ckpt_file)
tf.train.write_graph(sess.graph.as_graph_def(), './freeze_model/', "graph.pb", as_text=False)
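YOLOV3 from core.yolov3 is specific to that project, so the sketch below replaces it with a hypothetical one-variable model to show what this step produces: a binary graph.pb plus a checkpoint, which are the two inputs freeze_graph.py needs. It also shows why "input/input_data" appears in output_node_names: tf.name_scope('input') prefixes the placeholder's op name. Assumes TensorFlow 2.x with tf.compat.v1; it initializes the variable instead of restoring a trained checkpoint.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1

out_dir = tempfile.mkdtemp()
graph = tf.Graph()
with graph.as_default():
    # name_scope prefixes the op name: the placeholder becomes "input/input_data".
    with tf1.name_scope("input"):
        input_data = tf1.placeholder(dtype=tf.float32, shape=[None, 2], name="input_data")
    # Hypothetical stand-in for YOLOV3(...): one variable and a named output op.
    w = tf1.get_variable("w", initializer=tf.constant([[1.0], [2.0]]))
    pred = tf.identity(tf.matmul(input_data, w), name="pred")
    saver = tf1.train.Saver()
    with tf1.Session(graph=graph) as sess:
        sess.run(tf1.global_variables_initializer())  # the original restores trained weights instead
        ckpt_path = saver.save(sess, os.path.join(out_dir, "model.ckpt"))
        tf1.train.write_graph(sess.graph.as_graph_def(), out_dir, "graph.pb", as_text=False)

print(input_data.op.name, pred.op.name)
print(sorted(os.listdir(out_dir)))
```

graph.pb holds only the structure (variables are still variables); the weights live in the checkpoint files, and freezing merges the two.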
- Freeze the model:
REM Windows batch script (cmd.exe); ^ continues the command on the next line
python C:\Python27\Lib\site-packages\tensorflow\python\tools\freeze_graph.py ^
--input_graph=.\freeze_model\graph.pb ^
--output_graph=.\freeze_model\frozen.pb ^
--input_checkpoint=.\checkpoint\yolov3_coco_demo.ckpt ^
--output_node_names=input/input_data,pred_sbbox/concat_2,pred_mbbox/concat_2,pred_lbbox/concat_2 ^
--input_binary=true
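To avoid hard-coding C:\Python27\Lib\site-packages\..., the same call can be built from Python and run with the current interpreter. This sketch assumes the script can be launched as the module tensorflow.python.tools.freeze_graph of the installed package; the flags and paths are the ones from the batch script above. It only builds the argument list; running it needs TensorFlow and the files from the previous step.

```python
import sys


def build_freeze_graph_cmd(input_graph, input_checkpoint, output_graph,
                           output_node_names, input_binary=True):
    """Build the argv list for tensorflow/python/tools/freeze_graph.py.

    Assumes the script is run as a module of the installed tensorflow package.
    """
    return [
        sys.executable, "-m", "tensorflow.python.tools.freeze_graph",
        "--input_graph=" + input_graph,
        "--output_graph=" + output_graph,
        "--input_checkpoint=" + input_checkpoint,
        # Several output nodes are passed as one comma-separated flag value.
        "--output_node_names=" + ",".join(output_node_names),
        "--input_binary=" + ("true" if input_binary else "false"),
    ]


output_node_names = ["input/input_data", "pred_sbbox/concat_2",
                     "pred_mbbox/concat_2", "pred_lbbox/concat_2"]
cmd = build_freeze_graph_cmd("./freeze_model/graph.pb",
                             "./checkpoint/yolov3_coco_demo.ckpt",
                             "./freeze_model/frozen.pb",
                             output_node_names)
print(" ".join(cmd))
# To actually run it: import subprocess; subprocess.run(cmd, check=True)
```

Passing a list (rather than one shell string) sidesteps quoting differences between cmd.exe and POSIX shells.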