深度学习-性能优化3:TF-TRT的简单使用
1、在训练的时候保存为pb模型
import tensorflow as tf

# Build a tiny graph: a constant named 'input' added to itself, with the
# result exposed under the node name 'output'.
matrix_1 = tf.constant([3., 3.], name='input')
add = tf.add(matrix_1, matrix_1, name='output')

# The context manager closes the session even if freezing or writing fails.
with tf.Session() as sess:
    sess.run(add)
    # Freeze the graph, keeping only what the 'output' node depends on.
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_node_names=['output'])
    # Save the model into the ./model directory, creating it if needed
    # (GFile does not create missing parent directories).
    tf.gfile.MakeDirs('./model')
    with tf.gfile.GFile('./model/tensorflow_matrix_graph.pb', mode='wb') as f:
        f.write(output_graph_def.SerializeToString())
2、在训练的时候保存为ckpt模型
# Write a checkpoint of the current session. `sess` is not defined in this
# snippet; presumably it is the training session of the surrounding code.
saver = tf.train.Saver()
# save() returns the path prefix of the checkpoint files it wrote.
last_ckpt = saver.save(sess, 'model/model.ckpt')
3.ckpt转为pb模型---冻结
# NOTE: `input_checkpoint` (checkpoint path prefix) and `output_graph`
# (destination .pb path) are not defined in this snippet; the caller must
# set them before it runs.
output_node_names = "InceptionV3/Logits/SpatialSqueeze"
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
graph = tf.get_default_graph()  # the default graph, now holding the imported structure
input_graph_def = graph.as_graph_def()  # serialized GraphDef of the current graph
with tf.Session() as sess:
    saver.restore(sess, input_checkpoint)  # load the variable values into the graph
    # Persist the model by turning every variable into a constant.
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess=sess,
        input_graph_def=input_graph_def,  # equivalent to sess.graph_def
        # Several output nodes may be listed, separated by commas.
        output_node_names=output_node_names.split(","))
    with tf.gfile.GFile(output_graph, "wb") as f:  # save the frozen model
        f.write(output_graph_def.SerializeToString())  # serialized output
刚开始的时候老是出现如下错误:
assert d in name_to_node_map, "%s is not in graph" % d
AssertionError: A is not in graph
这说明 output_node_names 中指定的节点名(这里是 A)在图中不存在。可以使用下列 for 循环打印出所有节点,找到正确的输出节点名:
# Restore the latest checkpoint, then list the name of every operation in
# the default graph so the real output node name can be looked up.
# `self.sess` and `saved_model_path` come from the surrounding code.
saver = tf.train.Saver()
saver.restore(self.sess, tf.train.latest_checkpoint(saved_model_path))
for op in tf.get_default_graph().get_operations():
    print(op.name)
同理,接下来也可以使用如下代码获取tensor的节点名字:
# `graph` is a tf.Graph obtained earlier (e.g. tf.get_default_graph()).
for op in graph.get_operations():
    # op.values() holds the operation's output tensors.
    print(op.name, op.values())
4.ckpt不先转换为pb文件,直接冻结图并使用
import os

# Best effort: switch into the example directory if it exists below the
# current working directory; otherwise stay where we are.
try:
    os.chdir(os.path.join(os.getcwd(), 'examples/classification'))
    print(os.getcwd())
except OSError:
    # A missing directory or lack of permission both surface as OSError;
    # anything else is a real bug and should not be swallowed.
    pass
#%%
from PIL import Image
import sys
import os
import urllib
import tensorflow as tf
import tensorflow.contrib.tensorrt as trt
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
MODEL = 'inception_v1'
CHECKPOINT_PATH = 'inception_v1.ckpt'
NUM_CLASSES = 1001
LABELS_PATH = './data/imagenet_labels_%d.txt' % NUM_CLASSES
IMAGE_PATH = './data/dog-yawning.jpg'

# NOTE: build_classification_graph must already be defined or imported by
# the time this statement runs.
frozen_graph, input_names, output_names = build_classification_graph(
    model=MODEL,
    checkpoint=CHECKPOINT_PATH,
    num_classes=NUM_CLASSES
)

# Convert the frozen graph with TF-TRT.
trt_graph = trt.create_inference_graph(
    input_graph_def=frozen_graph,
    outputs=output_names,
    max_batch_size=1,
    max_workspace_size_bytes=1 << 25,  # 32 MiB
    precision_mode='FP16',
    minimum_segment_size=50
)

tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True
tf_sess = tf.Session(config=tf_config)
try:
    # Import the optimized graph into the default graph used by tf_sess.
    tf.import_graph_def(trt_graph, name='')
    tf_input = tf_sess.graph.get_tensor_by_name(input_names[0] + ':0')
    tf_output = tf_sess.graph.get_tensor_by_name(output_names[0] + ':0')

    image = Image.open(IMAGE_PATH)
    plt.imshow(image)

    # The input placeholder is laid out as (batch, height, width, channels),
    # while PIL's resize() expects a (width, height) tuple.
    input_shape = tf_input.shape.as_list()
    height = int(input_shape[1])
    width = int(input_shape[2])
    image = np.array(image.resize((width, height)))

    # image[None, ...] adds the leading batch dimension.
    output = tf_sess.run(tf_output, feed_dict={
        tf_input: image[None, ...]
    })
finally:
    # Release the session (and its GPU memory) even if inference fails.
    tf_sess.close()

scores = output[0]
with open(LABELS_PATH, 'r') as f:
    labels = f.readlines()

# Indices of the five highest scores, best first.
top5_idx = scores.argsort()[::-1][0:5]
for i in top5_idx:
    print('(%3f) %s' % (scores[i], labels[i]))
def build_classification_graph(model, checkpoint, num_classes):
    """Build a classification network, load its checkpoint and freeze it.

    Relies on the module-level names ``NETS``, ``input_name``,
    ``output_name``, ``slim`` and ``convert_relu6``, which are defined
    elsewhere.

    Args:
        model: Key into ``NETS`` selecting the network description.
        checkpoint: Checkpoint path handed to ``tf.train.Saver.restore``.
        num_classes: Number of classes passed to the network builder.

    Returns:
        A tuple ``(frozen_graph, [input_name], [output_name])``: the frozen
        graph after ``convert_relu6`` and the input/output node names, each
        wrapped in a one-element list.

    Raises:
        KeyError: If ``model`` is not a key of ``NETS`` (assuming ``NETS``
            is a plain dict).
    """
    net = NETS[model]

    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True

    # Build in a private graph so the caller's default graph is untouched.
    with tf.Graph().as_default() as tf_graph:
        with tf.Session(config=tf_config) as tf_sess:
            tf_input = tf.placeholder(
                tf.float32,
                [None, net.input_height, net.input_width, 3],
                name=input_name)
            tf_preprocessed = net.preprocess(tf_input)

            with slim.arg_scope(net.arg_scope()):
                tf_net, tf_end_points = net.model(
                    tf_preprocessed,
                    is_training=False,
                    num_classes=num_classes)

            tf_output = net.postprocess(tf_net, name=output_name)

            # load checkpoint
            tf_saver = tf.train.Saver()
            tf_saver.restore(save_path=checkpoint, sess=tf_sess)

            # freeze graph
            frozen_graph = tf.graph_util.convert_variables_to_constants(
                tf_sess,
                tf_sess.graph_def,
                output_node_names=[output_name]
            )

            # remove relu 6
            frozen_graph = convert_relu6(frozen_graph)

    return frozen_graph, [input_name], [output_name]
其他参考文章
https://github.com/yazone/ai_learning_path
https://github.com/yazone/ai_learning_path/tree/master/04.AI%E8%90%BD%E5%9C%B0%E5%AE%9E%E8%B7%B5/04.%E6%80%A7%E8%83%BD%E4%BC%98%E5%8C%96/nvidia%E6%80%A7%E8%83%BD%E4%BC%98%E5%8C%96
可以关注我们的微信公众号,持续会有AI系统学习的内容。