TorchScript是什么
TorchScript 可以看作Python的一个子集,主要的应用场景是把Python/PyTorch代码转换成等价的C++代码从而提高深度学习模型在线上生产环境部署的运行效率。Python代码会被编译成TorchScript编译器可以理解的一种格式(ScriptModule),C++的生产环境可以载入该格式的文件并用内置的JIT来执行对应的代码。
TorchScript提供了两种方法来把Python代码转换成TorchScript representation,分别为:
- Tracing
import torch
import torchvision
# An instance of your model.
model = torchvision.models.resnet18()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via **tracing**.
traced_script_module = torch.jit.trace(model, example)
- Scripting
class MyModule(torch.nn.Module):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
def forward(self, input):
if input.sum() > 0:
output = self.weight.mv(input)
else:
output = self.weight + input
return output
my_module = MyModule(10,20)
sm = torch.jit.script(my_module)
在这里我们会从源代码层面去分析TorchScript是如何实现上面这两种方案的,并且每个方案自身的限制有哪些,来源是什么。
Tracing
如上面的例子所示,
tracing
的应用场景是当你已经有了一个nn.Module (PyTorch里面定义神经网络的基本单元之一),你可以随便构造一个输入,然后告诉tracer:
- 把这个网络运行一遍
- 在这个网络运行过程中,所有操作都记录下来
- 记录下来的操作自然可以被复原成代码
非常简单直接,也跟这个名字非常吻合——在这个模型运行的过程中,有一台跟踪者在跟踪每一步的执行然后记录,因而得名tracer。
具体的实现方式,是通过在Operator的代码(也是C++)里面加上额外的追踪代码
[1]
[2]
。因为PyTorch的Module还有Operator为了执行效率本来就是C++的代码bind到了Python环境,所以在运行这个网络的过程之中自然会执行到Operator的C++源码,而其中就顺带执行了追踪代码。在追踪代码里面,每执行一个operator,就会往当前TracingState(定义成一个线程局部变量)里面的graph加入一个node。所有代码执行完毕,每一步的操作就会以一个Computation Graph里的某个节点的形式被保存下来。
(由于Python是单线程的,所以整个Computation Graph代码的执行顺序也是线性的,不用担心多线程带来的混乱。)
但是
Tracing
有如下限制:
- 这个神经网络的执行不能是data dependent的——也就说,不能有if else或者不等长的loop
- 只支持Tensor操作,而不支持其他操作
跟实现方法一对照,我们很容易可以理解为什么有这些限制
- 追踪出来的Computation Graph是静态的。如果模型是data dependent的话,那么不同的输入所追踪出来的graph是不一样的(因为graph本身不能表示if else或者loop等操作)。如果硬是套用,就会带来不正确的结果
- 追踪代码是跟Tensor Operator绑定在一起的,所以不走Operator的Python 逻辑是没法被追踪的
Scripting
Scripting
,从上面的例子来看,似乎跟
Tracing
区别不大,但是其实现方法非常不一样。概括而言,
scripting
是通过把Python的源代码解析成
语法树
,然后转化成C++可执行代码来实现的。
因为是直接编译源代码,除了应用在nn.Module上面,script也可以直接被用来annotate一般的python class/function,并且可以支持条件语句等
tracing
不能处理的情况。但这也有缺点:现在的实现只能支持编译Python语法特定子集的代码,因此存在一部分的代码在
tracing
里可以work但在
scripting
这边由于编译器的限制不支持。
如下是
scripting
的实现细节(根据Python源代码的来源不同会有差别):
#1 从
最简单的
开始:如果代码来源是Python函数(def foo()),那么大致流程如下:
- 读取Python源代码文本,将其解析成AST (抽象语法树)(
代码
) - 取得一个CompilationUnit对象(
代码
) - 通过CompilationUnit::define方法把AST转换成一个C++可执行的函数对象StrongFunctionPtr (
script_compile_function
)
#2 如果需要转换的代码是一个类(class Foo(object)),大致流程跟Python函数的case差不多,不过有一些限制
和
细微差别:
- 不支持继承——只支持直接继承object的类
- Class body里面只能定义methods (e.g. def foo()),不能定义其他代码
- AST解析的时候需要recursively把这个类的方法也都分别解析成对应的AST
- 在调用CompilationUnit::define的时候需要传一个"self"参数以表示这个是一个类里面的成员函数而不是全局函数
#3 如果需要转换的代码来自于一个nn.Module (PyTorch里面用来定义神经网络的类)的
实例
,大致流程会相对复杂一点:
- 首先,需要把这个nn.Module的instance映射到一个ConcreteModuleType (为什么要这么做的原因在
这里
有记载) - 遍历找到这个nn.Module的所有方法,方法
和
#2类似但是会去掉overload函数以及被标记成无用的函数(
代码
),然后按类似的方法调用CompilationUnit编译 - 拷贝所有的attributes/parameters
- 遍历找到所有的children nn.Module并编译
两者取舍
个人观点(也包括跟在PyTorch组工作的Engineer讨论得出的结论)
- 大部分情况model只有tensor operation,就直接无脑
tracing - 带control-flow (if-else, for-loop) 的,上
scripting - 碰上
scripting
不能handle的语法,要么重写,要么把
tracing
和
scripting
合起来用(比如说只在有control-flow的代码用
scripting
,其他用
tracing
)