saved_model 转 tensorrt 的 plan 模型

本文脚本运行环境及依赖:

  • python 3.5.2
  • tnsorflow 1.14.0
  • tensorrt 5.0.2.6
  • uff 0.5.5

收到一个需要将 tensorflow 下的 saved_model 模型转成 tensorrt 模型后运行在 tensorrtserver 的任务。tensorrtserver 只支持 tensorrt 下的 plan 模型,所以第一步需要将 saved_ model 模型转成 plan 模型。本文总结了将 tensorflow 下的 saved 模型转成 tensorrt 下 plan 模型的过程。plan 模型是可以直接运行在 tensorrtserver 下的模型文件。saved_model 是 tensorflow 下模型持久化的格式之一。

如下的代码展示了一个矩阵相乘的 saved_model 的文件的生成:

import tensorflow as tf
import numpy as np

input0 = tf.placeholder(tf.float32, [None, None], "input0")
input1 = tf.placeholder(tf.float32, [None, None], "input1")

b = tf.Variable(2.0,name='b')
#矩阵相乘
output = tf.matmul(input0, input1, name="matmul")

with tf.compat.v1.Session() as sess:
    sess.run(tf.global_variables_initializer())
    v = sess.run([output], feed_dict={input0: np.ones([3, 2], np.float32), input1: np.ones([2, 3], np.float32)})
    
    print(v)
    print(b)
    
    #变量及输出的 tensor 信息
    inp0 = tf.saved_model.utils.build_tensor_info(input0)
    inp1 = tf.saved_model.utils.build_tensor_info(input1)
    out = tf.saved_model.utils.build_tensor_info(output)
    #输入输出签名
    sign = tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'input0': inp0, 'input1': inp1},
        outputs={'output': out},
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

    builder = tf.saved_model.builder.SavedModelBuilder("export")
    #模型要被 tensorrtserver 运行,必须以如下的方式保存
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map={ tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: sign },
      main_op=tf.tables_initializer(),
      strip_default_attrs=True)

    builder.save()
    print ("export end!!")

执行之前需要清空或者删除 export 目录,生成的文件目录结构如下:

-export
----saved_model.pb
----variables
-------variables.data-00000-of-00001
-------variables.index

将 saved_model 模型转成 plan 需要分成两个步骤,第一步是生成 frozen_graph 模型,再将 frozen 模型转成 plan 模型。如下的代码可以将 export 目录中的模型转换成名称为 frozen_graph.pb 的 frozen 模型,模型保存在脚本同路径。注意其中 out_name 的配置要和 saved_model 中的输出节点的 op 名称一致,如果不知道,最好找到提供 saved_mode的人得到准确的答案,不要在这上面浪费时间,网上有文章可以找到模型中的 op 列表,但是不不能指出具体是哪个。

from tensorflow.python.framework import graph_io
from tensorflow.python.tools import freeze_graph
from tensorflow.python.training import saver as saver_lib
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants
import os

def convert_to_frozen(sess, models_dir, model_filename, out_name):
	try:
        #加载 saved_model
		tf.saved_model.loader.load(sess, [tag_constants.SERVING], models_dir)
		saver = saver_lib.Saver(allow_empty=True)
		checkpoint_path = saver.save(sess, 'saved_ckpt', global_step=0, latest_filename='checkpoint_state')
		graph_io.write_graph(sess.graph, '.', 'tmp.pb')
		
		model_path = model_filename
		freeze_graph.freeze_graph('./tmp.pb', '',
	                          	False, checkpoint_path, out_name,
	                          	"save/restore_all", "save/Const:0",
	                          	model_path , False, "")
	finally:
		#移除中间文件
		try: os.remove("tmp.pb") 
		except: pass
		try: os.remove("checkpoint_state")
		except: pass
		try: os.remove("saved_ckpt-0.meta")
		except: pass
		try: os.remove("saved_ckpt-0.index")
		except: pass
		try: os.remove("saved_ckpt-0.data-00000-of-00001")
		except: pass
	
	return model_path, out_name
	
models_dir = "export/"
out_name = "matmul"
with tf.Session() as sess:
	model_path, out_name = convert_to_frozen(sess, models_dir, "frozen_graph.pb", out_name)
	
	print("freeze model finished: ", model_path)
	print(out_name)

调用如下的脚本可以将 frozen 模型转成 plan 模型(下面的脚本需要在 GPU 的机器上执行)

import uff
import tensorrt as trt
from tensorrt.legacy.parsers import uffparser
#from tensorrt.parsers import uffparser

model_path = "frozen_graph.pb"
out_name = "matmul"

#这一步其实包含了将 saved_model 转成 uff 文件的过程
uff_model = uff.from_tensorflow_frozen_model(model_path, [out_name])
parser = uffparser.create_uff_parser()

G_LOGGER = trt.legacy.infer.ConsoleLogger(trt.legacy.infer.LogSeverity.INFO)
engine = trt.legacy.utils.uff_to_trt_engine(G_LOGGER, uff_model,
                 parser, 1,
                 1<<30 ,datatype=trt.legacy.infer.DataType.FLOAT)

trt.legacy.utils.cwrite_engine_to_file(model_path + 'model.plan',engine.serialize())

注意 tensorrt 的版本,如果提示找不到 tensorrt.legacy 模型,就将其删除,将 import 行改成注释中的那样,并将在脚本中用到 legacy 的地方删除。legacy 是遗产的意思,legacy 里面的内容是旧版的 tensorrt 的模块,新版的 tensorrt 将其作为遗产保留下来了。所以不同的 tensorrt 的版本,脚本的代码不同。

以上的第一第二个脚本也可以通过执行如下命令完成:

convert-to-uff
freeze_graph

这两个命令在安装完 tensorflow 和 uff 以及 tensorrt 后就会提供,在系统命令行执行。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值