为了深度学习模型的移植,首先需要将训练好的模型.ckpt的三个模型保存成.pb模型,在网上找到了很多方法,但是困难重重,中间经历找不到输入输出node,找到之后输出的模型不能进行预测,后来终于找到了方法,这里记录一下。我成功啦~
步骤如下:
1.首先获得输入输出节点的名字
代码如下:
import tensorflow as tf
import os
from tensorflow.python.tools import freeze_graph
from core.yolov3 import YOLOV3
input_size = 416
with tf.name_scope('input'):
img_input = tf.placeholder(dtype=tf.float32, shape=(None, input_size, input_size, 3), name='input_data') # 输入节点
model = YOLOV3(img_input, False) # mYOLOv3模型
print(img_input)#输出输入节点名称
print(model.pred_sbbox, model.pred_mbbox, model.pred_lbbox)#输出输出接电脑名称
会输出:
Tensor("input/input_data:0", shape=(?, 416, 416, 3), dtype=float32)
Tensor("pred_sbbox/concat_2:0", shape=(?, ?, ?, 3, 10), dtype=float32) Tensor("pred_mbbox/concat_2:0", shape=(?, ?, ?, 3, 10), dtype=float32) Tensor("pred_lbbox/concat_2:0", shape=(?, ?, ?, 3, 10), dtype=float32)
2.转换成pb文件
#! /usr/bin/env python
# coding=utf-8
import tensorflow as tf
import os
from tensorflow.python.tools import freeze_graph
from core.yolov3 import YOLOV3
input_size = 416
output_node_names = ["pred_sbbox/concat_2", "pred_mbbox/concat_2", "pred_lbbox/concat_2"]#输出节点
ckpt_filename = "./pbmodel/new/yolov3.ckpt" #设置model的路径,因新版tensorflow会生成三个文件,只需写到数字前
pb_file = "./pbmodel/yolov3_new.pb"
def ckpt2pb2():
with tf.Graph().as_default() as graph_old:
with tf.name_scope('input'):
img_input = tf.placeholder(dtype=tf.float32, shape=(None, input_size, input_size, 3), name='input_data')#输入节点
model = YOLOV3(img_input, False)#mYOLOv3模型
print(img_input)
print(model.pred_sbbox, model.pred_mbbox, model.pred_lbbox)
isess = tf.InteractiveSession()
isess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(isess, ckpt_filename)
constant_graph = tf.graph_util.convert_variables_to_constants(isess, isess.graph_def,
output_node_names)
constant_graph = tf.graph_util.remove_training_nodes(constant_graph)
with tf.gfile.GFile(pb_file, mode='wb') as f:
f.write(constant_graph.SerializeToString())
print("%d ops in the final graph." % len(constant_graph.node)) # 得到当前图有几个操作节点
if __name__ == '__main__':
ckpt2pb2()
3.使用tensorboard查看pb文件的图
1.先使用如下代码读取pb文件,在pblog文件中生成了event,然后再tensorboard上查看图
import tensorflow as tf
with tf.Session() as sess:
model_filename ='./pbmodel/yolov3_new.pb'#模型路径
with tf.gfile.GFile(model_filename, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
g_in = tf.import_graph_def(graph_def)
train_writer = tf.summary.FileWriter("./pblog")#保存
train_writer.add_graph(sess.graph)
train_writer.flush()
train_writer.close()
2.打开prompt,输入tensorboard --logdir=“路径”,然后就可以看啦。
4.使用pb模型进行预测
我这里使用的是project的demo文件,直接修改了pb的名称
总结
娃哈哈,四天的时间,第一天不知道熟悉代码,第二天查资料如果打开ckpt并知道怎么得到输入输出的节点名称,第三天终于能够生成了pb文件,但是不能进行预测,第四天,换了新的方法,终于成功啦!!!哈哈哈哈哈