前言:
将深度学习算法模型部署到移动端时,往往需要先对tf的ckpt模型进行冻结(freeze),也就是将ckpt模型转换为pb模型,之后再读取pb模型进行推理。
代码实现:
1.将ckpt模型转为pb模型:
import tensorflow as tf
from tensorflow.python.framework import graph_util
import cv2
from utils import tools
import numpy as np
from model.head.yolov3 import YOLOV3 #加载自己的模型
from tensorflow.python.tools import freeze_graph
pb_file = "./yolov3.pb"
ckpt_file = "./ckpt/yolo.ckpt"
image_file = "./test.png"
# Tensor names (op name plus ":0" output index); these are fetched by sess.run().
output_node_names = ['YoloV3/pred_sbbox/concat_2:0', 'YoloV3/pred_mbbox/concat_2:0', 'YoloV3/pred_lbbox/concat_2:0']
# convert_variables_to_constants() is given op names, i.e. the same names
# without the ":<output index>" suffix. Derive them so the two lists cannot drift.
output_node_names_1 = [name.rsplit(':', 1)[0] for name in output_node_names]


def strip_ref_ops(graph_def):
    """Rewrite reference/assignment ops of *graph_def* in place.

    For every node in ``graph_def.node``:

    * ``RefSwitch`` becomes ``Switch``; each of its inputs whose name contains
      ``'moving_'`` gets ``'/read'`` appended.
    * ``AssignSub`` becomes ``Sub`` and ``AssignAdd`` becomes ``Add``; the
      ``use_locking`` attribute is dropped.
    * ``Assign`` becomes ``Identity``; ``use_locking`` and ``validate_shape``
      are dropped and, when the node has exactly two inputs, only the assigned
      value (the second input) is kept.

    All other nodes are left untouched. Nothing is returned.

    NOTE(review): the ``'moving_'`` substring presumably targets batch-norm
    moving_mean / moving_variance variables -- confirm against the model.
    """
    for node in graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # Index-based loop on purpose: elements are reassigned while walking
            # the repeated field. `range` works on both Python 2 and Python 3.
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'AssignAdd':
            node.op = 'Add'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'Assign':
            node.op = 'Identity'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
            if 'validate_shape' in node.attr:
                del node.attr['validate_shape']
            if len(node.input) == 2:
                # input0: ref: Should be from a Variable node. May be uninitialized.
                # input1: value: The value to be assigned to the variable.
                node.input[0] = node.input[1]
                del node.input[1]


def main():
    """Build the YOLOv3 graph, restore the checkpoint and write a frozen pb.

    Steps: build the graph, restore ``ckpt_file``, run one forward pass on
    ``image_file`` as a sanity check (printing the three output shapes),
    rewrite the ref ops, fold variables into constants and serialize the
    result to ``pb_file``.

    Raises:
        IOError: if ``image_file`` cannot be read by OpenCV.
    """
    # Both placeholders live under the 'input' scope, so their tensor names are
    # 'input/input_data:0' and 'input/training:0'.
    with tf.name_scope('input'):
        input_data = tf.placeholder(dtype=tf.float32, shape=[None, 544, 544, 3], name='input_data')
        training = tf.placeholder_with_default(False, shape=(), name='training')
    _, _, _, pred_sbbox, pred_mbbox, pred_lbbox = YOLOV3(training).build_nework(input_data, False)
    print(pred_sbbox, pred_mbbox, pred_lbbox)

    # cv2.imread returns None instead of raising when the file is missing or
    # unreadable; fail early with a clear message.
    image = cv2.imread(image_file)
    if image is None:
        raise IOError("cannot read image: " + image_file)
    yolo_input = tools.img_preprocess2(image, None, (544, 544), False)
    yolo_input = yolo_input[np.newaxis, ...]  # add the batch dimension

    # The context manager guarantees the session is closed even on failure.
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        saver = tf.train.Saver()
        saver.restore(sess, ckpt_file)

        # Sanity check: run the restored model once before freezing it.
        pred_sbbox, pred_mbbox, pred_lbbox = sess.run(
            output_node_names, feed_dict={input_data: yolo_input, training: False})
        print(pred_sbbox.shape)
        print(pred_mbbox.shape)
        print(pred_lbbox.shape)

        gd = sess.graph.as_graph_def()
        strip_ref_ops(gd)
        converted_graph_def = graph_util.convert_variables_to_constants(sess, gd, output_node_names_1)

    with tf.gfile.GFile(pb_file, "wb") as f:
        f.write(converted_graph_def.SerializeToString())


if __name__ == "__main__":
    main()
2.读取pb模型:
import tensorflow as tf
import cv2
from utils import tools
import numpy as np
pb_file = "./yolov3.pb"
image_file = "./test.png"
# Tensor names (op name plus ":0" output index) fetched from the frozen graph.
output_node_names = ['YoloV3/pred_sbbox/concat_2:0', 'YoloV3/pred_mbbox/concat_2:0', 'YoloV3/pred_lbbox/concat_2:0']


def strip_ref_ops(graph_def):
    """Rewrite reference/assignment ops of *graph_def* in place.

    For every node in ``graph_def.node``:

    * ``RefSwitch`` becomes ``Switch``; each of its inputs whose name contains
      ``'moving_'`` gets ``'/read'`` appended.
    * ``AssignSub`` becomes ``Sub`` and ``AssignAdd`` becomes ``Add``; the
      ``use_locking`` attribute is dropped.
    * ``Assign`` becomes ``Identity``; ``use_locking`` and ``validate_shape``
      are dropped and, when the node has exactly two inputs, only the assigned
      value (the second input) is kept.

    All other nodes are left untouched. Nothing is returned.

    NOTE(review): the ``'moving_'`` substring presumably targets batch-norm
    moving_mean / moving_variance variables -- confirm against the model.
    """
    for node in graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # Index-based loop on purpose: elements are reassigned while walking
            # the repeated field. `range` works on both Python 2 and Python 3.
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'AssignAdd':
            node.op = 'Add'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'Assign':
            node.op = 'Identity'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
            if 'validate_shape' in node.attr:
                del node.attr['validate_shape']
            if len(node.input) == 2:
                # input0: ref: Should be from a Variable node. May be uninitialized.
                # input1: value: The value to be assigned to the variable.
                node.input[0] = node.input[1]
                del node.input[1]


def main():
    """Load the frozen pb, run one image through it and print output shapes.

    Raises:
        IOError: if ``image_file`` cannot be read by OpenCV.
    """
    # GFile matches the API used when the pb was written; FastGFile is deprecated.
    with tf.gfile.GFile(pb_file, 'rb') as f:
        gd = tf.GraphDef()
        gd.ParseFromString(f.read())
    strip_ref_ops(gd)

    # Import into an explicit graph. Graph.as_default() only has an effect when
    # used as a context manager, so the import is wrapped in `with`.
    graph = tf.Graph()
    with graph.as_default():
        # name='' keeps the original node names (no 'import/' prefix).
        tf.import_graph_def(gd, name='')
    input_data = graph.get_tensor_by_name('input/input_data:0')
    training = graph.get_tensor_by_name('input/training:0')

    # cv2.imread returns None instead of raising when the file is missing or
    # unreadable; fail early with a clear message.
    image = cv2.imread(image_file)
    if image is None:
        raise IOError("cannot read image: " + image_file)
    yolo_input = tools.img_preprocess2(image, None, (544, 544), False)
    yolo_input = yolo_input[np.newaxis, ...]  # add the batch dimension

    # No variable initializer is run: import_graph_def creates no entries in
    # the global-variables collection, so there is nothing to initialize.
    with tf.Session(graph=graph) as sess:
        pred_sbbox, pred_mbbox, pred_lbbox = sess.run(
            output_node_names, feed_dict={input_data: yolo_input, training: False})
    print(pred_sbbox.shape)
    print(pred_mbbox.shape)
    print(pred_lbbox.shape)


if __name__ == "__main__":
    main()
致谢:
感谢大佬lh的远程帮助,后续仔细阅读后再做详细说明。