TorchScript教程
TorchScript介绍
TorchScript是Pytorch模型(继承自nn.Module)的中间表示,可以在像C++这种高性能的环境中运行。
这个教程主要涵盖以下内容:
- Pytorch基础
- Modules
- 定义forward函数
- 将modules组合进modules
- 定义转换pytorch的modules到TorchScript的方法,一个高性能的部署运行:
- 追踪一个已存在的module
- 使用scripting来直接编译一个module
- 如何结合两者方法
- 保存和加载TorchScript的modules
Pytorch模型编写基础
一个Module定义基础包含以下内容:
- 一个构造器,准备一些初始化参数
- 一个Parameters和sub-Modules的集合,这些将会在构造器内初始化,如torch.nn.Linear之类的
- 一个forward函数,当module被调用时调用
如下一个示例模型:
class MyDecisionGate(torch.nn.Module):
def forward(self, x):
if x.sum() > 0:
return x
else:
return -x
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
self.dg = MyDecisionGate()
self.linear = torch.nn.Linear(4, 4)
def forward(self, x, h):
new_h = torch.tanh(self.dg(self.linear(x)) + h)
return new_h, new_h
my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))
#输出如下
MyCell(
(dg): MyDecisionGate()
(linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[-0.0657, 0.1869, 0.8526, 0.3125],
[ 0.6072, 0.7615, 0.6674, 0.7230],
[ 0.0875, 0.7908, 0.6205, 0.5743]], grad_fn=<TanhBackward>), tensor([[-0.0657, 0.1869, 0.8526, 0.3125],
[ 0.6072, 0.7615, 0.6674, 0.7230],
[ 0.0875, 0.7908, 0.6205, 0.5743]], grad_fn=<TanhBackward>))
其中torch.nn.Linear是一个Module都继承了Module,打印Module可以给出其层级。可以看到其子类和参数。
其中,输出中有grad_fn,允许我们潜在的计算导数。
这里我们加入了MyDecisionGate,这个modules使用了控制流control flow,包含了循环和if判断。
许多框架在给定完整程序的情况下计算符号的导数,但是Pytorch中在计算的时候记录,在计算中向后回放。这样就不用为所有的构造定义导数函数。
TorchScript基础
TorchScript提供了工具来捕捉模型的定义,即使在轻量灵活和动态的Pytorch中。我们通过tracing 的方法。
如下:
class MyCell(torch.nn.Module):
def __init__(self):
super(MyCell, self).__init__()
self.linear = torch.nn.Linear(4, 4)
def forward(self, x, h):
new_h = torch.tanh(self.linear(x) + h)
return new_h, new_h
my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)
# 输出
MyCell(
original_name=MyCell
(linear): Linear(original_name=Linear)
)
其中,我们调用了torch.jit.trace,传入Module和符合的示例输入。
它会调用Moduel并将操作记录下来,当Module运行时记录下操作,然后创建torch.jit.ScriptModule的实例(其中TracedModule是一个实例)
TorchScript记录下模型定义在中间表示中(Intermediate Representation (IR)),在深度学习中通常被称为graph,我们可以打印.graph属性。
print(traced_cell.graph)
# 输出
graph(%self.1 : __torch__.MyCell,
%input : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
%h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
%21 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)
%23 : Tensor = prim::CallMethod[name="forward"](%21, %input)
%14 : int = prim::Constant[value=1]() # /var/lib/jenkins/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:188:0
%15 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%23, %h, %14) # /var/lib/jenkins/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:188:0
%16 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%15) # /var/lib/jenkins/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:188:0
%17 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = pr