软件版本:
tensorflow == 1.12.0
python == 3.6.1
关于cpkt模型的加载与使用可以看这个
关于.pb模型
不同于cpkt的是.pb模型将模型参数和网络结构固化在同一个文件中,使用中无需分别读取参数和结构,通过tensorflow自带函数可一并读取至graph,之后操作graph即可完成运算。
模型保存
https://blog.csdn.net/u014568072/article/details/85281769
模型读取
首先将模型文件存放至代码的同级目录,我的模型文件名为yolov4-tiny.pb,所以将其输入GFile函数并设置mode为’rb’,运行下列代码。之后模型被存入变量graph。
with tf.gfile.GFile('yolov4-tiny.pb', "rb") as pb:
graph_def = tf.GraphDef()
graph_def.ParseFromString(pb.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(
graph_def,
name="", # name可以自定义,修改name之后记得在下面的代码中也要改过来
)
for op in graph.get_operations():
print(op.name, op.values()) # 打印网络结构
观察打印出的网络结构如下所示,左侧如“inputs”为operation名称,括号内如“inputs:0”为tensor名称,接下来使用get_tensor_by_name时要使用tensor名称。
inputs (<tf.Tensor 'inputs:0' shape=(?, 416, 416, 3) dtype=float32>,)
detector/truediv/y (<tf.Tensor 'detector/truediv/y:0' shape=() dtype=float32>,)
detector/truediv (<tf.Tensor 'detector/truediv:0' shape=(?, 416, 416, 3) dtype=float32>,)
detector/yolo-v4-tiny/Pad/paddings (<tf.Tensor 'detector/yolo-v4-tiny/Pad/paddings:0' shape=(4, 2) dtype=int32>,)
detector/yolo-v4-tiny/Pad (<tf.Tensor 'detector/yolo-v4-tiny/Pad:0' shape=(?, 418, 418, 3) dtype=float32>,)
detector/yolo-v4-tiny/Conv/weights (<tf.Tensor 'detector/yolo-v4-tiny/Conv/weights:0' shape=(3, 3, 3, 32) dtype=float32>,)
detector/yolo-v4-tiny/Conv/weights/read (<tf.Tensor 'detector/yolo-v4-tiny/Conv/weights/read:0' shape=(3, 3, 3, 32) dtype=float32>,)
detector/yolo-v4-tiny/Conv/Conv2D (<tf.Tensor 'detector/yolo-v4-tiny/Conv/Conv2D:0' shape=(?, 208, 208, 32) dtype=float32>,)
…………
模型运行
之后使用get_tensor_by_name获取输入和输出节点,并运行session即可得到结果
node_in = graph.get_tensor_by_name('inputs:0') # 此处填入输入节点名称
node_out = graph.get_tensor_by_name('detector/yolo-v4-tiny/Reshape_4:0') # 此处填入输出节点名称
with tf.Session(graph=graph) as sess: # Session()别忘了传入参数!
# sess.run(tf.global_variables_initializer()) # 因为是从模型中读取,所以无需初始化变量
feed_dict = {node_in: image} # image为node_in输入数据,有关代码已省略
pred = sess.run(node_out, feed_dict) #运行session,得到node_out
print(pred)
sess.close()
需要特别注意的一点是,tf.Session(graph=graph)要将graph作为参数传入!否则将会报错提示operation不存在!
运行程序可以看到pred值被成功打印
参考文章
https://www.jb51.net/article/179098.htm
https://blog.csdn.net/lujiandong1/article/details/53385092