tensorflow 读取模型并进行预测

tensorflow 读取两种格式的模型并进行预测

1. 模型保存

1.1 checkpoint 模型

如图所示,
.meta – 保存图结构,即神经网络的网络结构
.data – 保存数据文件,即网络的权值,偏置,操作等等
.index – 是一个不可变得字符串表,每一个键都是张量的名称,它的值是一个序列化的BundleEntryProto。 每个BundleEntryProto描述张量的元数据:“数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和,一些辅助数据等等。
checkpoint – 文本文件,里面记录了保存的最新的checkpoint文件以及其它checkpoint文件列表。在inference时,可以通过修改这个文件,指定使用哪个model.

保存模型:

saver = tf.train.Saver()
saver.save(sess, model_path)

其中model_path是模型保存路径。

1.2 frozen_graph模型

在工程中,我们往往需要将模型和权重固化,便于发布和预测。
使用tensorFlow官方提供的freeze_graph.py工具来保存相应模型。(代码中把freeze_graph.py文件放在commom.utils.tf路径下导入)

freeze_graph.py先加载模型文件,从checkpoint文件读取权重数据初始化到模型里的权重变量,再将权重变量转换成权重常量,然后再通过指定的输出节点将没用于输出推理的Op节点从图中剥离掉,再重新保存到指定的文件里(用write_graphdef或Saver)。

from tensorflow.core.protobuf import saver_pb2
from common.utils.tf import freeze_graph
# save model graph
tf.train.write_graph(
    sess.graph.as_graph_def(),
    os.path.join(model_path),
    GRAPH_PB_NAME,
    as_text=False)
# generate frozen graph
freeze_graph.freeze_graph(
    input_graph=os.path.join(model_path, GRAPH_PB_NAME),
    input_saver=False,
    input_binary=True,
    input_checkpoint=os.path.</
  • 6
    点赞
  • 59
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值