需要使用Pytorch的版本为1.0及以上:
conda install pytorch-nightly -c pytorch
第一步:将Pytorch模型转化为Torch script
Torch Script是连接C++和Python的桥梁,Pytorch模型的表示,可以被Torch Script编译器理解,编译和序列化.
如果想要C++使用Pytorch的模型,就必须先将Pytorch模型转化为Torch Script.在大多数情况下,这样的工作量都比较小,如果已经有了模型的Torch Script,那么下面的内容就不需要看了.
有两种方法,可以将Pytorch模型转化为Torch Script.
第一个方法是tracing.该方法通过将样本输入到模型中一次来对该过程进行评估从而捕获模型结构.并记录该样本在模型中的flow.该方法适用于模型中很少使用控制flow的模型.
第二个方法就是向模型添加显式注释,通知Torch Script编译器它可以直接解析和编译模型代码,受Torch Script语言强加的约束。
利用Tracing将模型转换为Torch Script
要通过tracing来将PyTorch模型转换为Torch脚本,必须将模型的实例以及样本输入传递给torch.jit.trace函数.
这将生成一个torch.jit.ScriptModule对象,并在模块的forward方法中嵌入模型评估的跟踪:
import torch
import torchvision
# An instance of your model.
model = torchvision.models.resnet18()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
现在,跟踪的ScriptModule可以与常规PyTorch模块进行相同的计算:
In[1]: output = traced_script_module(torch.ones(1, 3, 224, 224))
In[2]: output[0, :5]
Out[2]: tensor([-0.2698, -0.0381, 0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)
通过标注将Model转换为Torch Script
在某些情况下,例如,如果模型使用特定形式的控制流,如果想要直接在Torch脚本中编写模型并相应地标注模型。例如,假设有以下普通的 Pytorch模型:
import torch
class MyModule(torch.nn.Module):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
def forward(self, input):
if input.<