将tensorflow的ckpt模型转化为pb模型,可以大大提高网络预测速度,是进行部署的第一步。怎么做参考:这里。我看网上资料较少,我写一下怎么读取pb模型进行测试,通常落地会采用c++这种更底层的语言。
具体怎么写需要根据网络的测试代码来写,每个网络输入输出不一样,我在下面贴一个写好的只作为参考。
总体步骤:
1.读入pb文件
def freeze_graph_test2(pb_path, test_path):
with tf.Graph().as_default():
output_graph_def = tf.GraphDef()
with open(pb_path, "rb") as f:
output_graph_def.ParseFromString(f.read())
tf.import_graph_def(output_graph_def, name="")
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
这部分代码直接复制,tf的常规操作,把pb_path替换成自己的pb文件地址就行
2.定义输入
keep_probability = sess.graph.get_tensor_by_name(name="keep_probabilty:0")
image =sess.graph.get_tensor_by_name(name="input_image:0")
_, valid_records = scene_parsing.read_dataset(FLAGS.data_dir)
去看一下自己网络的测试代码,比如我的代码是这样:
所以用get_tensor_by_name取出这几个节点,注意后面必须加上索引号,如:0
3.定义输出
pred_annotation = sess.graph.get_tensor_by_name("inference/prediction:0")
inference/prediction:0是网络最后一层的名字
4.运行
到这步就跟测试代码一样就行了
pred = sess.run(pred_annotation, feed_dict={image: valid_images,keep_probability: 1.0})
所有代码:
import tensorflow as tf
from tensorflow.python.framework import graph_util
import os
import time
from datetime import timedelta
import numpy as np
import TensorflowUtils as utils
import read_MITSceneParsingData as scene_parsing
import datetime
import BatchDatsetReader as dataset
from six.moves import xrange
IMAGE_SIZE = 448
def freeze_graph_test2(pb_path, test_path):
with tf.Graph().as_default():
output_graph_def = tf.GraphDef()
with open(pb_path, "rb") as f:
output_graph_def.ParseFromString(f.read())
tf.import_graph_def(output_graph_def, name="")
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
keep_probability = sess.graph.get_tensor_by_name(name="keep_probabilty:0")
image =sess.graph.get_tensor_by_name(name="input_image:0")
_, valid_records = scene_parsing.read_dataset(FLAGS.data_dir)
pred_annotation = sess.graph.get_tensor_by_name("inference/prediction:0")
image_options = {'resize': True, 'resize_size': IMAGE_SIZE}
validation_dataset_reader = dataset.BatchDatset(valid_records, image_options)
sess = tf.Session()
valid_images, valid_annotations = validation_dataset_reader.get_random_batch(FLAGS.batch_size)
pred = sess.run(pred_annotation, feed_dict={image: valid_images,keep_probability: 1.0})
print("len:",valid_annotations[0].shape)
valid_annotations = np.squeeze(valid_annotations, axis=3)
for itr in range(FLAGS.batch_size):
utils.save_image(valid_images[itr].astype(np.uint8), FLAGS.logs_dir, name="inp_" + str(5+itr))
utils.save_image(valid_annotations[itr].astype(np.uint8), FLAGS.logs_dir, name="gt_" + str(5+itr))
utils.save_image(pred[itr].astype(np.uint8), FLAGS.logs_dir, name="pred_" + str(5+itr))
print("Saved image: %d" % itr)
if __name__ == '__main__':
out_pb_path = "../checkpoints/frozen_model.pb"
test_dir = "data/cnews/cnews.test.txt"
freeze_graph_test2(pb_path=out_pb_path,test_path=test_dir)