Tensorflow 关于.pb 模型保存的方法和相关调试问题

  1. 模型保存介绍

在老版本的TensorFlow中,对训练后的模型框架和参数分别保存在后缀为.ckpt和.meta的文件中。然而在新版本的TensorFlow中,模型的保存为三个文件:.ckpt-data、.ckpt-meta、.ckpt-index,以及一个名为checkpoint的文件.

其中,checkpoint 文件的意义在于只是告知TF function 哪一个文件是最后更新的checkpoint文档;

.ckpt-data 保存了不含模型结构的模型参数;在python中恢复模型时,可联合使用meta和data文件,或者单独使用.pb文件。

saver = tf.train.import_meta_graph(path_to_ckpt_meta)

saver.restore(sess, path_to_ckpt_data)

.ckpt-meta 保存了metagraph, 即不包含模型参数的模型结构。

.ckpt-index 的意义暂不明确,估计是对上述两个文件内部的一种映射关系,但该文件在恢复模型时不是必须的。

 

.pb 文件相当于 meta + data. 在c++语言中加载该模型时使用。

 

  1. pb模型转换

1.调试中出现的问题与解决:

在保存pb模型前,尝试利用meta和data两个文件直接恢复模型的框架和参数,并对单幅图像做测试。然而出现如下问题:

(1).单幅图像的运行程序,为了提高推断效率,关闭了GPU的使用,os.environ[“CUDA_VISIBLE_DEVICES”]=”-1”,从而造成在恢复meta是无法加载,其原因是因为训练过程中使用的GPU,也许保存的模型框架也跟GPU产生关联;当注释掉上述语句,可顺利加载;

(2) 然而接下来,加载后meta 文件中保存的是训练模型的框架,各个参数矩阵维度结构包含当时训练的时候的batch,我设的batch==5,而现在如果要单幅图像测试的话,程序跑到run sess就挂起了,因为输入维度为[5, 55, 47, 3]而不是[1, 55, 47, 3]。解决方法是,不要使用之前已经保存的meta文件,而是在程序执行中在

with tf.Graph().as_default():下面直接调用模型框架函数,并定义模型的输入维度为[1, 55, 47, 3],这样使得输出向量的批处理数值与输入保持一致。同时加载data文件,接着利用语句:

tf.train.write_graph(sess.graph_def,pb_save_path,pb_save_name)

将模型保存为一个.pb文档文件,该文件保存的是模型的结构文档。(注意:这里虽然文件后缀也是.pb,但其不是freeze后的二进制pb文件,而是模型结构文档)

对应生成该pb文档的位置:/home/weihua/models/research/slim/inference_deepID-v1_pb.py

后面将会利用该pb文档和-ckpt.data 联合生成pb模型文件,从而避免了输入输出结构函数维度不匹配的问题。

(3)pb模型的保存

/home/weihua/git/tensorflow/tensorflow/tensorflow/python/tool/freeze_graph .py 文件, 输入参数如下:

----input_graph=/home/weihua/models/research/slim/polar_skeleton-models/deepID_v1/model_deepID_v1.pb

\ 该输入为刚才保存的模型

----input_checkpoint=/home/weihua/models/research/slim/polar_skeleton-models/deepID_v1/model.ckpt-30000

\该输入为模型中变量的参数值

----output_graph=/home/weihua/models/research/slim/polar_skeleton-models/deepID_v1/frozen_deepID_v1.pb

\该输出为整合后需要保存的名称及位置

--output_node_names=Softmax

 

这个比较重要,只从哪个节点开始保存。之前我想从预测节点开始往上保存,虽然预测节点不在deepID_v1卷积网络模型内,但pb文档保存了所有的模型框架,因此可以从该节点开始保存。预测节点对应的名称为:Softmax:0

(2)pb模型的测试

测试文件:/home/weihua/models/research/slim/inference_deepID-v1_pb_read.py

pb模型加载方法:

with tf.Graph().as_default():

output_graph_def = tf.GraphDef()

 

with open(pb_file_path, "rb") as f:

output_graph_def.ParseFromString(f.read())

_ = tf.import_graph_def(output_graph_def, name="")

其次,需要对模型的输入和输出接口的名称从加载的模型框架中读取:

logits = sess.graph.get_tensor_by_name("Softmax:0")

img_tensor = sess.graph.get_tensor_by_name("Placeholder:0")

然后sess run, 程序成功运行。

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值