tensorflow 读取两种格式的模型并进行预测
文章目录
1. 模型保存
1.1 checkpoint 模型
如图所示,
.meta
– 保存图结构,即神经网络的网络结构
.data
– 保存数据文件,即网络的权值,偏置,操作等等
.index
– 是一个不可变得字符串表,每一个键都是张量的名称,它的值是一个序列化的BundleEntryProto。 每个BundleEntryProto描述张量的元数据:“数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和,一些辅助数据等等。
checkpoint
– 文本文件,里面记录了保存的最新的checkpoint文件以及其它checkpoint文件列表。在inference时,可以通过修改这个文件,指定使用哪个model.
保存模型:
saver = tf.train.Saver()
saver.save(sess, model_path)
其中model_path
是模型保存路径。
1.2 frozen_graph模型
在工程中,我们往往需要将模型和权重固化,便于发布和预测。
使用tensorFlow
官方提供的freeze_graph.py
工具来保存相应模型。(代码中把freeze_graph.py
文件放在commom.utils.tf
路径下导入)
freeze_graph.py
先加载模型文件,从checkpoint文件读取权重数据初始化到模型里的权重变量,再将权重变量转换成权重常量,然后再通过指定的输出节点将没用于输出推理的Op节点从图中剥离掉,再重新保存到指定的文件里(用write_graphdef或Saver)。
from tensorflow.core.protobuf import saver_pb2
from common.utils.tf import freeze_graph
# save model graph
tf.train.write_graph(
sess.graph.as_graph_def(),
os.path.join(model_path),
GRAPH_PB_NAME,
as_text=False)
# generate frozen graph
freeze_graph.freeze_graph(
input_graph=os.path.join(model_path, GRAPH_PB_NAME),
input_saver=False,
input_binary=True,
input_checkpoint=os.path.</