在C++中加载TorchScript模型
本教程已更新为可与PyTorch 1.2一起使用
顾名思义,PyTorch的主要接口是Python编程语言。尽管Python是合适于许多需要动态性和易于迭代的场景,并且是首选的语言,但同样的,在 许多情况下,Python的这些属性恰恰是不利的。后者通常适用的一种环境是要求生产-低延迟和严格部署。对于生产场景,即使只将C ++绑定到Java, Rust或Go之类的另一种语言中,它也是经常选择的语言。以下各段将概述PyTorch提供的从现有Python模型到可以完全从C ++加载和执行的序 列化表示形式的路径,而无需依赖Python。
步骤1:将PyTorch模型转换为Torch脚本
PyTorch模型从Python到C ++的旅程由Torch Script启动,Torch Script是PyTorch模型的一种表示形式,可以由Torch Script编译器理解, 编译和序列化。如果是从使用vanilla“eager” API编写的现有PyTorch模型开始的,则必须首先将模型转换为Torch脚本。在最常见的情况 下(如下所述),这只需要花费很少的功夫。如果已经有了Torch脚本模块,则可以跳到本教程的下一部分。
有两种将PyTorch模型转换为Torch脚本的方法。第一种称为跟踪,一种机制,其中通过使用示例输入对模型的结构进行一次评估,并记录这些 输入在模型中的流量,从而捕获模型的结构。这适用于有限使用控制流的模型。第二种方法是在模型中添加显式批注,以告知Torch Script编 译器可以根据Torch Script语言施加的约束直接解析和编译模型代码。
提示:可以在官方Torch脚本参考中找到有关这两种方法的完整文档,以及使用方法的进一步指导。
方法1:通过跟踪转换为Torch脚本
要将PyTorch模型通过跟踪转换为Torch脚本,必须将模型的实例以及示例输入传递给torch.jit.trace函数。这将产生一个torch.jit.ScriptModule 对象,该对象的模型评估痕迹将嵌入模块的forward方法中:
import torch
import torchvision
模型的一个实例.
model = torchvision.models.resnet18()
通常会提供给模型的forward()方法的示例输入。
example = torch.rand(1, 3, 224, 224)
使用torch.jit.trace
来通过跟踪生成torch.jit.ScriptModule
traced_script_module = torch.jit.trace(model, example)
现在可以对跟踪的ScriptModule进行评估,使其与常规PyTorch模块相同:
In[1]: output = traced_script_module(torch.ones(1, 3, 224, 224))
In[2]: output[0, :5]
Out[2]: tensor([-0.2698, -0.0381, 0.4023, -0.3010, -0.0448], grad_fn=)
方法2:通过注释转换为Torch脚本
在某些情况下,例如,如果模型采用特定形式的控制流,则可能需要直接在Torch脚本中编写模型并相应地注释模型。例如,假设具有以下 vanilla Pytorch模型:
import torch
class MyModule(torch.nn.Module):
def init(self,