深度学习模型ckpt转成pb模型(YOLOv3为例)

为了深度学习模型的移植,首先需要将训练好的模型.ckpt的三个模型保存成.pb模型,在网上找到了很多方法,但是困难重重,中间经历找不到输入输出node,找到之后输出的模型不能进行预测,后来终于找到了方法,这里记录一下。我成功啦~
步骤如下:

1.首先获得输入输出节点的名字

代码如下:

import tensorflow as tf
import os
from tensorflow.python.tools import freeze_graph
from core.yolov3 import YOLOV3
input_size = 416
with tf.name_scope('input'):
    img_input = tf.placeholder(dtype=tf.float32, shape=(None, input_size, input_size, 3), name='input_data')  # 输入节点
model = YOLOV3(img_input, False)  # mYOLOv3模型
print(img_input)#输出输入节点名称
print(model.pred_sbbox, model.pred_mbbox, model.pred_lbbox)#输出输出接电脑名称

会输出:

Tensor("input/input_data:0", shape=(?, 416, 416, 3), dtype=float32)
Tensor("pred_sbbox/concat_2:0", shape=(?, ?, ?, 3, 10), dtype=float32) Tensor("pred_mbbox/concat_2:0", shape=(?, ?, ?, 3, 10), dtype=float32) Tensor("pred_lbbox/concat_2:0", shape=(?, ?, ?, 3, 10), dtype=float32)

2.转换成pb文件

#! /usr/bin/env python
# coding=utf-8
import tensorflow as tf
import os
from tensorflow.python.tools import freeze_graph
from core.yolov3 import YOLOV3
input_size = 416
output_node_names = ["pred_sbbox/concat_2", "pred_mbbox/concat_2", "pred_lbbox/concat_2"]#输出节点
ckpt_filename = "./pbmodel/new/yolov3.ckpt" #设置model的路径,因新版tensorflow会生成三个文件,只需写到数字前
pb_file       = "./pbmodel/yolov3_new.pb"
def ckpt2pb2():
    with tf.Graph().as_default() as graph_old:
        with tf.name_scope('input'):
            img_input = tf.placeholder(dtype=tf.float32, shape=(None, input_size, input_size, 3), name='input_data')#输入节点
        model = YOLOV3(img_input, False)#mYOLOv3模型
        print(img_input)
        print(model.pred_sbbox, model.pred_mbbox, model.pred_lbbox)

        isess = tf.InteractiveSession()

        isess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(isess, ckpt_filename)

        constant_graph = tf.graph_util.convert_variables_to_constants(isess, isess.graph_def,
                                                                      output_node_names)
        constant_graph = tf.graph_util.remove_training_nodes(constant_graph)

        with tf.gfile.GFile(pb_file, mode='wb') as f:
            f.write(constant_graph.SerializeToString())
        print("%d ops in the final graph." % len(constant_graph.node))  # 得到当前图有几个操作节点
if __name__ == '__main__':
    ckpt2pb2()

3.使用tensorboard查看pb文件的图

1.先使用如下代码读取pb文件,在pblog文件中生成了event,然后再tensorboard上查看图

import tensorflow as tf
with tf.Session() as sess:
    model_filename ='./pbmodel/yolov3_new.pb'#模型路径
    with tf.gfile.GFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        g_in = tf.import_graph_def(graph_def)
        train_writer = tf.summary.FileWriter("./pblog")#保存
        train_writer.add_graph(sess.graph)
        train_writer.flush()
        train_writer.close()

2.打开prompt,输入tensorboard --logdir=“路径”,然后就可以看啦。

在这里插入图片描述

4.使用pb模型进行预测

我这里使用的是project的demo文件,直接修改了pb的名称
在这里插入图片描述

总结

娃哈哈,四天的时间,第一天不知道熟悉代码,第二天查资料如果打开ckpt并知道怎么得到输入输出的节点名称,第三天终于能够生成了pb文件,但是不能进行预测,第四天,换了新的方法,终于成功啦!!!哈哈哈哈哈

  • 3
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值