TensorFlow：ckpt 模型转 pb 模型及 pb 模型的读取

前言:

将深度学习模型部署到移动端时，往往需要对 TensorFlow 的 ckpt 模型进行冻结（freeze），即把计算图中的变量替换为常量，使网络结构和权重合并保存到单个 pb 文件中。这就涉及两个步骤：将 ckpt 模型转换为 pb 模型，以及读取 pb 模型进行推理。

代码实现:

1.将ckpt模型转为pb模型:

import tensorflow as tf
from tensorflow.python.framework import graph_util
import cv2
from utils import tools
import numpy as np
from model.head.yolov3 import YOLOV3  #加载自己的模型
from tensorflow.python.tools import freeze_graph

pb_file = "./yolov3.pb"
ckpt_file = "./ckpt/yolo.ckpt"
# Tensor names (op name + ":0" output index), used to fetch results via sess.run.
output_node_names = ['YoloV3/pred_sbbox/concat_2:0', 'YoloV3/pred_mbbox/concat_2:0', 'YoloV3/pred_lbbox/concat_2:0']
# Bare op names (no ":0" suffix), as required by convert_variables_to_constants.
# Derived from output_node_names so the two lists cannot drift apart.
output_node_names_1 = [name.split(':')[0] for name in output_node_names]


def _fix_ref_ops_for_freezing(graph_def):
    """Rewrite reference-typed (stateful) ops in ``graph_def`` in place.

    convert_variables_to_constants replaces variables with constants, but ops
    that expect a variable *reference* (e.g. ``float_ref`` inputs) cannot take
    a constant, and freezing commonly fails on them. Each such op is swapped
    for its value-semantics counterpart:

    * ``RefSwitch`` -> ``Switch``; inputs whose name contains ``moving_`` are
      redirected to ``<input>/read`` (presumably batch-norm moving statistics;
      TODO confirm for this model).
    * ``AssignSub`` -> ``Sub`` and ``AssignAdd`` -> ``Add``, dropping the
      ``use_locking`` attr.
    * ``Assign`` -> ``Identity`` of the assigned value, dropping
      ``use_locking`` / ``validate_shape``.

    Args:
        graph_def: a ``tf.GraphDef``; modified in place.
    """
    for node in graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'AssignAdd':
            node.op = 'Add'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'Assign':
            node.op = 'Identity'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
            if 'validate_shape' in node.attr:
                del node.attr['validate_shape']
            if len(node.input) == 2:
                # input0: ref: should be from a Variable node; may be uninitialized.
                # input1: value: the value to be assigned to the variable.
                # Identity takes a single input, so forward the value and drop the ref.
                node.input[0] = node.input[1]
                del node.input[1]


with tf.name_scope('input'):
    input_data = tf.placeholder(dtype=tf.float32, shape=[None, 544, 544, 3], name='input_data')
    training = tf.placeholder_with_default(False, shape=(), name='training')

# build_nework returns six tensors; only the last three (the prediction heads)
# are needed for inference and export.
_, _, _, pred_sbbox, pred_mbbox, pred_lbbox = YOLOV3(training).build_nework(input_data, False)
print(pred_sbbox, pred_mbbox, pred_lbbox)

image = cv2.imread("./test.png")
if image is None:
    # cv2.imread signals failure by returning None instead of raising.
    raise IOError("could not read test image ./test.png")
yolo_input = tools.img_preprocess2(image, None, (544, 544), False)
yolo_input = yolo_input[np.newaxis, ...]  # prepend a batch axis

# The session is a context manager so it is closed even if freezing fails.
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    saver = tf.train.Saver()
    saver.restore(sess, ckpt_file)

    # Sanity check: run the restored model once before freezing it.
    pred_sbbox, pred_mbbox, pred_lbbox = sess.run(output_node_names, feed_dict={input_data: yolo_input, training: False})
    print(pred_sbbox.shape)
    print(pred_mbbox.shape)
    print(pred_lbbox.shape)

    # as_graph_def() returns a copy, so the rewrite below leaves sess.graph intact.
    gd = sess.graph.as_graph_def()
    _fix_ref_ops_for_freezing(gd)
    converted_graph_def = graph_util.convert_variables_to_constants(sess, gd, output_node_names_1)

with tf.gfile.GFile(pb_file, "wb") as f:
    f.write(converted_graph_def.SerializeToString())

2.读取pb模型

import tensorflow as tf
import cv2
from utils import tools
import numpy as np

pb_file = "./yolov3.pb"
# Tensor names (op name + ":0" output index) of the three prediction heads.
output_node_names = ['YoloV3/pred_sbbox/concat_2:0', 'YoloV3/pred_mbbox/concat_2:0', 'YoloV3/pred_lbbox/concat_2:0']


def _fix_ref_ops_for_freezing(graph_def):
    """Rewrite reference-typed (stateful) ops in ``graph_def`` in place.

    Ops that expect a variable *reference* cannot be fed from constants in a
    frozen graph, so each is swapped for its value-semantics counterpart:

    * ``RefSwitch`` -> ``Switch``; inputs whose name contains ``moving_`` are
      redirected to ``<input>/read`` (presumably batch-norm moving statistics;
      TODO confirm for this model).
    * ``AssignSub`` -> ``Sub`` and ``AssignAdd`` -> ``Add``, dropping the
      ``use_locking`` attr.
    * ``Assign`` -> ``Identity`` of the assigned value, dropping
      ``use_locking`` / ``validate_shape``.

    Args:
        graph_def: a ``tf.GraphDef``; modified in place.
    """
    for node in graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'AssignAdd':
            node.op = 'Add'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'Assign':
            node.op = 'Identity'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
            if 'validate_shape' in node.attr:
                del node.attr['validate_shape']
            if len(node.input) == 2:
                # input0: ref: should be from a Variable node; may be uninitialized.
                # input1: value: the value to be assigned to the variable.
                # Identity takes a single input, so forward the value and drop the ref.
                node.input[0] = node.input[1]
                del node.input[1]


with tf.gfile.GFile(pb_file, 'rb') as f:
    gd = tf.GraphDef()
    gd.ParseFromString(f.read())

# Defensive: if the graph was already rewritten before freezing, this pass
# should find nothing left to change.
_fix_ref_ops_for_freezing(gd)

image = cv2.imread("./test.png")
if image is None:
    # cv2.imread signals failure by returning None instead of raising.
    raise IOError("could not read test image ./test.png")
yolo_input = tools.img_preprocess2(image, None, (544, 544), False)
yolo_input = yolo_input[np.newaxis, ...]  # prepend a batch axis

with tf.Session() as sess:
    # Import into the session's own graph; as_default() only takes effect
    # when entered as a context manager.
    with sess.graph.as_default():
        tf.import_graph_def(gd, name='')

    input_data = sess.graph.get_tensor_by_name('input/input_data:0')
    training = sess.graph.get_tensor_by_name('input/training:0')

    # No variable initialization is needed: import_graph_def does not register
    # any variables in the global-variables collection, and a frozen graph
    # carries its weights as constants.
    pred_sbbox, pred_mbbox, pred_lbbox = sess.run(output_node_names, feed_dict={input_data: yolo_input, training: False})

print(pred_sbbox.shape)
print(pred_mbbox.shape)
print(pred_lbbox.shape)

致谢:

感谢大佬lh的远程帮助,后续仔细阅读后再做详细说明。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值