背景
- PyTorch的主要接口是Python语言。虽然Python是许多需要动态和易于迭代的场景的首选语言,但同样有很多情况下,Python的这些属性恰好是不利的。在生产环境中,需要保证低延迟和其它严格的要求,而对于生产场景,C++通常是首选语言,通常C++会被绑定到Java或Go当中;
- 第一种方法是在PyTorch中直接使用C++编写PyTorch的前端,而不是通常情况下使用Python来编写PyTorch的前端,实现模型定义、训练和评估以及模型的部署;
- 第二种方法是使用Python编写PyTorch的前端,并且实现上述功能;
- 众所周知,Python相对于C++在不考虑执行效率的情况下具有很多优势,本文不会讨论这方面的问题。因此,如果可以使用Python编写前端,实现模型定义、训练和评估,而将模型的部署交由C++实现,则可以最大化目标,最快地获得模型以及部署高效的模型。
- 概念转换成具体的方案,将Python在PyTorch下训练得到的模型文件转化成C++可以加载和执行的模型文件,并且自此以后不再依赖于Python;
- PyTorch模型从Python到C ++之旅由Torch Script实现,Torch Script是PyTorch模型的一种表示,并且可以由Torch Script的编译器理解、编译和序列化。
环境
- 安装PyTorch的版本为1.0及以上;
- 安装C++版本的LibTorch,LibTorch发行版包含一组共享库,头文件和CMake构建配置文件;
- 安装Intel所提供的MKL-DNN库,Caffe依赖此库,编译得到的可执行程序会依赖其中的
libmklml
和libiomp5
动态链接库,若无此库在执行程序时会有错误产生,安装后将这两个动态链接库拷贝至LibTorch的lib目录下;
Mac OS 报错信息如下:
dyld: Library not loaded: @rpath/libmklml.dylib> dyld: Library not loaded: @rpath/libmklml.dylib
Referenced from: ******
Reason: image not found
- 安装CMake。
将PyTorch模型转化为Torch Script的两种方法
- 如果需要C++使用PyTorch的模型,就必须先将PyTorch模型转化为Torch Script;
- 目前有两种方法,可以将PyTorch模型转化为Torch Script:
- 第一种方法是tracing。该方法通过将样本输入到模型中,并对该过程进行推断从而捕获模型的结构,并记录该样本在模型中的控制流。该方法适用于模型中较少使用控制流的模型;
- 第二种方法是向模型中添加显式的注释,使得Torch Script编译器可以直接解析和编译模型的代码,受Torch Script强加的约束。该方法适用于使用特定控制流的模型,。
利用Tracing将模型转换为Torch Script
通过tracing的方法将PyTorch的模型转换为Torch Script,则必须将模型的实例以及样本输入传递给torch.jit.trace
方法。这样会生成一个torch.jit.ScriptModule
对象,模型中的forward
方法中用预先嵌入了模型推断的跟踪机制:
import torch
import torchvision
# An instance of your model.
model = torchvision.models.resnet18()
# model.eval()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
# ScriptModule
output = traced_script_module(torch.ones(1, 3, 224, 224))
利用注释将模型转换为Torch Script
- 在某些情况下,如模型采用特定形式的控制流(
if...else...
),可以使用注释的方法将模型转化为Torch Script; - 此模块的
forward
方法使用依赖于输入的控制流,因此它不适合利用tracing的方法生成Torch Script。可以通过继承torch.jit.ScriptModule
并将@torch.jit.script_method
标注添加到模型的forward中的方法,来将模型转换为ScriptModule
。
import torch
class MyModule(torch.nn.Module):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
def forward(self, input):
if input.sum() > 0:
output = self.weight.mv(input)
else:
output = self.weight + input
return output
import torch
class MyModule(torch.jit.ScriptModule):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand