Chapter 8 WORKING WITH DEEP LEARNING FRAMEWORKS
使用Python API,使用TensorFlow、Caffe或ONNX兼容框架构建的现有模型可提供提供的解析器构建TensorRT引擎。 Python API还支持以NumPy兼容格式存储层权重的框架,例如PyTorch。
8.1 支持的框架操作
以下列表描述了Caffe或TensorFlow框架和ONNX TensorRT解析器中支持的操作:
-
caffe
这些是Caffe框架解析器支持的操作:
‣ Convolution
‣ Pooling
‣ InnerProduct
‣ SoftMax
‣ ReLU, TanH, and Sigmoid
‣ LRN
‣ Power
‣ ElementWise
‣ Concatenation
‣ Deconvolution
‣ BatchNormalization
‣ Scale
‣ Crop
‣ Reduction
‣ Reshape
‣ Permute
‣ Dropout -
Tensorflow:
这些是TensorFlow框架支持的操作:
‣ Placeholder
‣ Const
‣ Add, Sub, Mul, Div, Minimum and Maximum
‣ BiasAdd
‣ Negative, Abs, Sqrt, Rsqrt, Pow, Exp and Log
NvUffParser仅支持const节点的Neg,Abs,Sqrt,Rsqrt,Exp和Log
‣ FusedBatchNorm
‣ ReLU, TanH, and Sigmoid
‣ SoftMax
如果TensorFlow SoftMax op的输入不是NHWC,TensorFlow将自动插入具有非常量置换的转置层,导致UFF转换器失败。因此,建议使用常量置换手动将SoftMax输入转置到NHWC。
‣ Mean
‣ ConcatV2
‣ Reshape
‣ Transpose
‣ Conv2D
‣ DepthwiseConv2dNative
‣ ConvTranspose2D
‣ MaxPool
‣ AvgPool
‣ 如果后面跟着Conv2D,DepthwiseConv2dNative,MaxPool和AvgPool中某一个TensorFlow层,则支持Pad。 -
ONNX:
由于ONNX解析器是一个开源项目,因此可以在GitHub:ONNX TensorRT中找到有关支持的操作的最新信息。这些是ONNX框架支持的操作:
‣ Abs
‣ Add
‣ AveragePool
‣ BatchNormalization
‣ Ceil
‣ Clip
‣ Concat
‣ Constant
‣ Conv
‣ ConvTranspose
‣ DepthToSpace
‣ Div
‣ Dropout
‣ Elu
‣ Exp
‣ Flatten
‣ Floor
‣ Gemm
‣ GlobalAveragePool
‣ GlobalMaxPool
‣ HardSigmoid
‣ Identity
‣ InstanceNormalization
‣ LRN
‣ LeakyRelu
‣ Log
‣ LogSoftmax
‣ MatMul
‣ Max
‣ MaxPool
‣ Mean
‣ Min
‣ Mul
‣ Neg
‣ PRelu
‣ Pad
‣ Pow
‣ Reciprocal
‣ ReduceL1
‣ ReduceL2
‣ ReduceLogSum
‣ ReduceLogSumExp
‣ ReduceMax
‣ ReduceMean
‣ ReduceMin
‣ ReduceProd
‣ ReduceSum
‣ ReduceSumSquare
‣ Relu
‣ Reshape
‣ Selu
‣ Shape
‣ Sigmoid
‣ Size
‣ Softmax
‣ Softplus
‣ SpaceToDepth
‣ Split
‣ Squeeze
‣ Sub
‣ Sum
‣ Tanh
‣ TopK
‣ Transpose
‣ Unsqueeze
‣ Upsample
8.2 使用Tensorflow
有关将TensorRT与TensorFlow模型一起使用的信息,请参阅end_to_end_tensorflow_mnist Python示例。
8.2.1 冻结TensorFlow图
要使用命令行UFF实用程序,必须冻结TensorFlow图并将其另存为.pb文件。有关更多信息,请参阅:
‣ A Tool Developer’s Guide to TensorFlow Model Files: Freezing
‣ Exporting trained TensorFlow models to C++ the RIGHT way!
8.2.2 冻结Keras模型
您可以使用以下示例代码冻结Keras模型。
from keras.models import load_model
import keras.backend as K
from tensorflow.python.framework import graph_io
from tensorflow.python.tools import freeze_graph
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.training import saver as saver_lib
def convert_keras_to_pb(keras_model, out_names, models_dir,model_filename):
model = load_model(keras_model)
K.set_learning_phase(0)
sess = K.get_session()
saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)
checkpoint_path = saver.save(sess, 'saved_ckpt', global_step=0,latest_filename='checkpoint_state')
graph_io.write_graph(sess.graph, '.', 'tmp.pb')
freeze_graph.freeze_graph('./tmp.pb', '',False, checkpoint_path, out_names,"save/restore_all", "save/Const:0",models_dir+model_filename, False, "")
8.2.3 将冻结图转换为UFF
您可以使用以下示例代码将.pb冻结图转换为.uff格式文件。
convert-to-uff input_file [-o output_file] [-O output_node]
您可以列出TensorFlow层:
convert-to-uff input_file -l
8.2.4 使用TensorFlow RNN权重
本节提供有关TensorFlow权重及其存储格式的信息。此外,以下部分将指导您如何从TensorFlow接近和解密RNN权重。
8.2.4.1 TensorRT中支持的TensorFlow RNN单元
8.2.5 使用Graph Surgeon API预处理TensorFlow图
Graph Surgeon API(也称为graphurgeon)允许您转换TensorFlow图。其功能大致分为两类:
- 检索
-搜索功能允许您在TensorFlow图中查找节点。 - 操作
- 操作函数允许您修改,添加或删除节点。
使用graphurgeon,您可以将某些节点(或节点集)标记为图中的插件节点。这些插件既可以是TensorRT附带的插件,也可以是您编写的插件。有关更多信息,请参阅使用自定义图层扩展TensorRT。
如果您正在编写插件,请参阅有关如何实现IPluginExt和IPluignCreator类以及注册插件的详细信息,请参阅使用自定义图层扩展TensorRT。
以下代码片段说明了如何使用graphurgeon将TensorFlow Leaky ReLU操作映射到TensorRT Leaky ReLU插件节点。
import graphsurgeon as gs
lrelu_node = gs.create_plugin_node(name=”trt_lrelu”, op=”LReLU_TRT”,negSlope=0.2)
namespace_plugin_map = { “tf_lrelu” : lrelu_node }
# Transform TensorFlow graph using graphsurgeon and save to UFF
dynamic_graph = gs.DynamicGraph(tf_lrelu.graph)
dynamic_graph.collapse_namespaces(namespace_plugin_map)
# Run UFF converter using new graphdef
uff_model = uff.from_tensorflow(dynamic_graph.as_graph_def(), ["trt_lrelu"],output_filename=”test_lrelu.uff”, text=True)
在上面的代码中,create_plugin_node方法中的op字段应该与注册的插件名称匹配。这使得UFF解析器能够使用该字段在插件注册表中查找插件,以将插件节点插入网络中。
有关一个有效的graphurgeon示例,请参阅sampleUffSSD for C ++。
有关graphurgeon API的更多详细信息,请参阅Graph Surgeon API。
8.3 使用PyTorch和其他框架
将TensorRT与PyTorch和其他框架一起使用涉及使用TensorRT API复制网络架构,然后从PyTorch(或具有NumPy兼容权重的任何其他框架)复制权重。有关将TensorRT与PyTorch模型一起使用的更多信息,请参阅network_api_pytorch_mnist Python示例。