PyTorch模型转换为TorchScript实战指南

TorchScript是什么

TorchScript 可以看作Python的一个子集,主要的应用场景是把Python/PyTorch代码转换成等价的C++代码从而提高深度学习模型在线上生产环境部署的运行效率。Python代码会被编译成TorchScript编译器可以理解的一种格式(ScriptModule),C++的生产环境可以载入该格式的文件并用内置的JIT来执行对应的代码。

TorchScript提供了两种方法来把Python代码转换成TorchScript representation,分别为:

  • Tracing
import torch
import torchvision

# An instance of your model.
model = torchvision.models.resnet18()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via **tracing**.
traced_script_module = torch.jit.trace(model, example)
  • Scripting
class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

my_module = MyModule(10,20)
sm = torch.jit.script(my_module)

在这里我们会从源代码层面去分析TorchScript是如何实现上面这两种方案的,并且每个方案自身的限制有哪些,来源是什么。

Tracing

如上面的例子所示,
tracing
的应用场景是当你已经有了一个nn.Module (PyTorch里面定义神经网络的基本单元之一),你可以随便构造一个输入,然后告诉tracer:

  • 把这个网络运行一遍
  • 在这个网络运行过程中,所有操作都记录下来
  • 记录下来的操作自然可以被复原成代码

非常简单直接,也跟这个名字非常吻合——在这个模型运行的过程中,有一台跟踪者在跟踪每一步的执行然后记录,因而得名tracer。

具体的实现方式,是通过在Operator的代码(也是C++)里面加上额外的追踪代码
[1]
[2]
。因为PyTorch的Module还有Operator为了执行效率本来就是C++的代码bind到了Python环境,所以在运行这个网络的过程之中自然会执行到Operator的C++源码,而其中就顺带执行了追踪代码。在追踪代码里面,每执行一个operator,就会往当前TracingState(定义成一个线程局部变量)里面的graph加入一个node。所有代码执行完毕,每一步的操作就会以一个Computation Graph里的某个节点的形式被保存下来。

(由于Python是单线程的,所以整个Computation Graph代码的执行顺序也是线性的,不用担心多线程带来的混乱。)

但是
Tracing
有如下限制:

  • 这个神经网络的执行不能是data dependent的——也就说,不能有if else或者不等长的loop
  • 只支持Tensor操作,而不支持其他操作

跟实现方法一对照,我们很容易可以理解为什么有这些限制

  • 追踪出来的Computation Graph是静态的。如果模型是data dependent的话,那么不同的输入所追踪出来的graph是不一样的(因为graph本身不能表示if else或者loop等操作)。如果硬是套用,就会带来不正确的结果
  • 追踪代码是跟Tensor Operator绑定在一起的,所以不走Operator的Python 逻辑是没法被追踪的
Scripting

Scripting
,从上面的例子来看,似乎跟
Tracing
区别不大,但是其实现方法非常不一样。概括而言,
scripting
是通过把Python的源代码解析成
语法树
,然后转化成C++可执行代码来实现的。

因为是直接编译源代码,除了应用在nn.Module上面,script也可以直接被用来annotate一般的python class/function,并且可以支持条件语句等
tracing
不能处理的情况。但这也有缺点:现在的实现只能支持编译Python语法特定子集的代码,因此存在一部分的代码在
tracing
里可以work但在
scripting
这边由于编译器的限制不支持。

如下是
scripting
的实现细节(根据Python源代码的来源不同会有差别):

#1 从
最简单的
开始:如果代码来源是Python函数(def foo()),那么大致流程如下:

  • 读取Python源代码文本,将其解析成AST (抽象语法树)(
    代码
  • 取得一个CompilationUnit对象(
    代码
  • 通过CompilationUnit::define方法把AST转换成一个C++可执行的函数对象StrongFunctionPtr (
    script_compile_function

#2 如果需要转换的代码是一个类(class Foo(object)),大致流程跟Python函数的case差不多,不过有一些限制


细微差别:

  • 不支持继承——只支持直接继承object的类
  • Class body里面只能定义methods (e.g. def foo()),不能定义其他代码
  • AST解析的时候需要recursively把这个类的方法也都分别解析成对应的AST
  • 在调用CompilationUnit::define的时候需要传一个"self"参数以表示这个是一个类里面的成员函数而不是全局函数

#3 如果需要转换的代码来自于一个nn.Module (PyTorch里面用来定义神经网络的类)的
实例
,大致流程会相对复杂一点:

  • 首先,需要把这个nn.Module的instance映射到一个ConcreteModuleType (为什么要这么做的原因在
    这里
    有记载)
  • 遍历找到这个nn.Module的所有方法,方法

    #2类似但是会去掉overload函数以及被标记成无用的函数(
    代码
    ),然后按类似的方法调用CompilationUnit编译
  • 拷贝所有的attributes/parameters
  • 遍历找到所有的children nn.Module并编译
两者取舍

个人观点(也包括跟在PyTorch组工作的Engineer讨论得出的结论)

  • 大部分情况model只有tensor operation,就直接无脑
    tracing
  • 带control-flow (if-else, for-loop) 的,上
    scripting
  • 碰上
    scripting
    不能handle的语法,要么重写,要么把
    tracing

    scripting
    合起来用(比如说只在有control-flow的代码用
    scripting
    ,其他用
    tracing
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值