Saving, Freezing, Optimizing for inference, Restoring of tensorflow models
在训练完tensorflow模型后,会有三个文件:model-epoch_99.data-00000-of-00001,model-epoch_99.index,model-epoch_99.meta
1.tensorflowModel.ckpt.meta:TensorFlow将图结构与变量值分开存储。文件.ckpt.meta包含完整的图结构(MetaGraphDef),其中包括GraphDef、SaverDef等。
2.tensorflowModel.ckpt.data-00000-of-00001:它包含所有变量的值(如权重、偏置,以及优化器的状态变量等)。占位符本身没有需要保存的值;梯度和超参数只有在以变量形式存在时才会被保存。
3.tensorflowModel.ckpt.index:这是一个表,其中每个键是张量tensor的名称,其值是序列化的BundleEntryProto。
- 第一步先生成tensorflowModel.pbtxt文件。可以在测试程序中,执行完saver.restore之后,将graph保存为.pbtxt。
import resnet_multitask
def classify_model(images, class_num):
    """Build the multitask ResNet-v2 network in inference mode.

    Args:
        images: Input tensor of three-channel colour images.
        class_num: Number of classes; sizes the final fully-connected layer.

    Returns:
        Tuple ``(logits, pre_heatmap, end_points)`` as produced by
        ``resnet_multitask.resnet_v2``.
    """
    # is_training=False selects the inference behaviour of the arg scope.
    inference_scope = resnet_multitask.resnet_arg_scope(is_training=False)
    with slim.arg_scope(inference_scope):
        net_logits, net_heatmap, net_end_points = resnet_multitask.resnet_v2(
            images, class_num)
    return net_logits, net_heatmap, net_end_points
# Rebuild the inference graph, restore the trained weights from the
# checkpoint, then dump the graph definition as a text protobuf (.pbtxt)
# that freeze_graph can consume.
# NOTE(review): w, h, c and class_num are not defined in this snippet and must
# match the training-time values. TF's default NHWC layout is
# [batch, height, width, channels] -- confirm that [None, w, h, c] is intended.
restore_path = './checkpoint/model-epoch_99'
with tf.Session() as sess:
    images_ph = tf.placeholder(tf.float32, shape=[None, w, h, c],
                               name='input_x')
    logits, pre_heatmap, end_points = classify_model(images_ph, class_num)
    tf.train.Saver().restore(sess, restore_path)
    # Only the graph structure is serialised here; the variable values stay
    # in the checkpoint files.
    restored_graph_def = sess.graph.as_graph_def()
    tf.train.write_graph(restored_graph_def, '.',
                         './checkpoint/tensorflowModel.pbtxt', as_text=True)
- 接下来用freeze_graph把.pbtxt中的图结构与checkpoint中的变量值合并(所有变量被转换为常量),生成冻结后的.pb文件,参见https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py
from tensorflow.python.tools import freeze_graph
# Merge the text GraphDef with the trained checkpoint and write a frozen .pb
# in which every variable has been replaced by a constant.
freeze_graph.freeze_graph(
    input_graph='./checkpoint/tensorflowModel.pbtxt',
    input_saver="",
    input_binary=False,  # the .pbtxt was written with as_text=True
    # Use the same checkpoint prefix that was restored when the .pbtxt was
    # written (files model-epoch_99.data-*/.index/.meta).
    input_checkpoint='./checkpoint/model-epoch_99',
    # Node name (no ":0" suffix); the tensor fetched at inference time is
    # "resnet_v2/predictions/Reshape_1:0".
    output_node_names="resnet_v2/predictions/Reshape_1",
    restore_op_name="save/restore_all",  # unused by freeze_graph
    filename_tensor_name="save/Const:0",  # unused by freeze_graph
    output_graph='./checkpoint/model.pb',
    clear_devices=True,
    initializer_nodes="",
)
# NOTE: reference excerpt of the signature and docstring of TensorFlow's
# freeze_graph (tensorflow/python/tools/freeze_graph.py). The implementation
# body is not reproduced, so calling this stub does nothing and returns None.
# The defaults need `tag_constants` and `saver_pb2` in scope (upstream they are
# presumably imported from tensorflow.python.saved_model and
# tensorflow.core.protobuf -- verify against your TF version).
# Defining this at module level would also shadow the `freeze_graph` module
# imported from tensorflow.python.tools.
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_whitelist="",
                 variable_names_blacklist="",
                 input_meta_graph=None,
                 input_saved_model_dir=None,
                 saved_model_tags=tag_constants.SERVING,
                 checkpoint_version=saver_pb2.SaverDef.V2):
    """Converts all variables in a graph and checkpoint into constants.

    Args:
      input_graph: A `GraphDef` file to load.
      input_saver: A TensorFlow Saver file.
      input_binary: A Bool. True means input_graph is .pb, False indicates
        .pbtxt.
      input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
        priority. Typically the result of `Saver.save()` or that of
        `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
        V1/V2.
      output_node_names: The name(s) of the output nodes, comma separated.
        These are node names (e.g. "resnet_v2/predictions/Reshape_1"), not
        tensor names, so they carry no ":0" suffix.
      restore_op_name: Unused.
      filename_tensor_name: Unused.
      output_graph: String where to write the frozen `GraphDef`.
      clear_devices: A Bool whether to remove device specifications.
      initializer_nodes: Comma separated list of initializer nodes to run
        before freezing.
      variable_names_whitelist: The set of variable names to convert
        (optional, by default, all variables are converted).
      variable_names_blacklist: The set of variable names to omit converting
        to constants (optional).
      input_meta_graph: A `MetaGraphDef` file to load (optional).
      input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
        and variables (optional).
      saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
        load, in string format.
      checkpoint_version: Tensorflow variable file format
        (saver_pb2.SaverDef.V1 or saver_pb2.SaverDef.V2).

    Returns:
      Upstream: string that is the location of the frozen GraphDef. (This
      excerpt has no body and returns None.)
    """
- 生成.pb文件后,可以通过tensorboard可视化
import tensorflow as tf
from tensorflow.python.platform import gfile
# Load the frozen GraphDef into the default graph and write it to an event
# file under ./logs/model so TensorBoard can render it.
model = './checkpoint/model.pb'
graph = tf.get_default_graph()
graph_def = tf.GraphDef()
# Read through a context manager so the file handle is closed
# (the original FastGFile(...).read() never closed it).
with gfile.GFile(model, 'rb') as f:
    graph_def.ParseFromString(f.read())
# Imported nodes appear under the "graph/" name scope in TensorBoard.
tf.import_graph_def(graph_def, name='graph')
summaryWriter = tf.summary.FileWriter('./logs/model', graph)
# Flush pending events to disk; without close() the graph event may not be
# written before the process exits.
summaryWriter.close()
写一个start_tensorboard.bat,内容如下,然后运行,打开浏览器,地址栏输入http://localhost:6006
REM Start TensorBoard on the event files written for the frozen .pb graph.
REM "cd /d" also switches the current drive, so this works even when the
REM .bat is launched from a drive other than C:.
cd /d C:\software\Anaconda3\Scripts
tensorboard.exe --logdir=C:\workspace\code\img_classify\logs\model
- 载入.pb模型进行前向运算
import tensorflow as tf
import numpy as np
import time
import cv2
def recognize(jpg_path, pb_file_path, input_size=(224, 224),
              input_tensor_name="input_x:0",
              output_tensor_name="resnet_v2/predictions/Reshape_1:0"):
    """Classify one image with a frozen ``.pb`` model and print the result.

    The image is read with OpenCV, converted BGR -> RGB, resized, scaled to
    [0, 1] and fed to the frozen graph as a batch of one.

    Args:
        jpg_path: Path of the image file to classify.
        pb_file_path: Path of the frozen ``GraphDef`` (``.pb``) file.
        input_size: ``(width, height)`` passed to ``cv2.resize``; must match
            the input size the network expects. Defaults to ``(224, 224)``.
        input_tensor_name: Name of the graph's input tensor.
        output_tensor_name: Name of the graph's prediction tensor.

    Returns:
        ``np.argmax`` over the prediction output. For a ``(1, num_classes)``
        output this is the predicted class index of the single input image.

    Raises:
        OSError: If ``cv2.imread`` cannot read ``jpg_path``.
    """
    img = cv2.imread(jpg_path)
    if img is None:
        # cv2.imread returns None instead of raising on a missing/unreadable
        # file; fail here with a clear message rather than inside cvtColor.
        raise OSError("cannot read image: {}".format(jpg_path))
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    test_img = cv2.resize(img_rgb, tuple(input_size))
    # Add the batch axis and scale pixel values from [0, 255] to [0, 1].
    test_img = np.asarray(test_img, np.float32)[np.newaxis, :] / 255.

    with tf.Graph().as_default() as graph:
        graph_def = tf.GraphDef()
        with open(pb_file_path, "rb") as f:
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name="")

        # A frozen graph stores its weights as constants, so there are no
        # variables to initialise before running it.
        with tf.Session(graph=graph) as sess:
            input_x = graph.get_tensor_by_name(input_tensor_name)
            out_softmax = graph.get_tensor_by_name(output_tensor_name)

            time_start = time.perf_counter()
            img_out_softmax = sess.run(out_softmax,
                                       feed_dict={input_x: test_img})
            time_end = time.perf_counter()

    print('run time: ', time_end - time_start, 's')
    print("img_out_softmax:", img_out_softmax)
    prediction_labels = np.argmax(img_out_softmax)
    print("label:", prediction_labels)
    return prediction_labels
# Example invocation: classify a single test image with the frozen model.
recognize(jpg_path=r'C:\data\test_image.jpg',
          pb_file_path="./checkpoint/model.pb")