torch2trt源码解析与使用之核心原理初探(一)
torch2trt源码解析与使用之核心原理初探(二)
为什么使用torch2trt?
在使用pytorch作为训练框架,后续需要将模型部署到嵌入式或者实时应用场景时,使用Nvidia的卡的话,Nvidia官方提供了一套tensorrt的加速方案。但是除tf和caffe外,其他框架不能直接使用。常用的torch–>trt的方案是torch–>onnx–>trt。GitHub上torch2trt这个项目完成了从torch–>trt的直接转换,接下来主要解析其源码看看这个项目是如何实现的。
核心思想
官方项目主页–How does it work?–对其工作原理有简单介绍。我会结合源码解释。其转换原理的核心在于构建一个CONVERTERS[method]字典,将pytorch module中对应的pytorch方法计算替换为Tensorrt的方法,并将这种method名称作key值,Tensorrt的对应method实现函数作为value存储在字典中。
torch–>trt的关键在于INetwork的构建
首先从torch模型转换到trt模型的核心是要完成trt的INetwork的构建。作者通过“绑定”每一个module的forward执行,到每一个module的converter,pytorch的model在预测时会具体到每一个module的forward函数,再到其绑定的converter以实现tensorrt的INetwork的构建。首先,以conv2d的模块为例看一看convert_conv2d实现里哪些功能。
先忽略ctx这个参数,以后我们会知道它是一个ConversionContext的自定义类。convert_con2d的前面一部分是在获取conv2d这一层的输入以及卷机所需的卷积核大小,padding,stride,dilation等参数,可以看到这些参数的获取是module的属性,而这个module在这里是pytorch型的module。关键的一句话在于:
ctx.network即为tensorrt的builder创建的网络,network = builder.create_network()。
除了函数本身外,还看到convert_conv2d被tensorrt_converter装饰,查看这个装饰器,
可以看到,正是在此处构建了上文说到的CONVERTERS字典,而字典的真正用处之后会解释到。从字典的keys和values可以看到,pytorch模块的“module_name”,module本身,(qual_name这里是指forward)以及重要的converter都存储在了字典中。而字典是通过装饰器构建的,所以在程序导入时就已经直接完成了注册。
convert_conv2d何时显式执行?
虽然CONVERTERS字典式自动建立的,但是convert_conv2d函数本身需要被显式调用,而如果去查看torch2trt这个函数本身,
module作为pytorch的模型只跑了一遍前向,tensorrt的network遍完成了构建。回顾之前说过的moduel每当执行到forward处就会自动构建trt的network,而从convert_conv2d中可以看到,必须显式调用才可构建,所谓的“自动构建”究竟是如何完成的?