这篇文章主要讨论 PyTorch JIT (Just In Time Compilation) 与TorchScript。JIT 和 TorchScript 之间有什么联系?它们如何为 Pytorch 引入静态特性,并方便模型部署?接下来,我们试着回答这些问题。
JIT
JIT 是一种概念,译为“即时编译”,是一种程序优化的方法。我们引用 PyTorch JIT 这篇文章中的栗子:
在 Python 中使用正则表达式:
prog = re.compile(pattern)
result = prog.match(string)
或者
result = re.match(pattern, string)
两种写法从结果上来说是等价的。但注意,第一种写法中,会先对正则表达式进行编译(compile),然后再进行使用。
Python 官方建议,如果多次使用到某一个正则表达式,建议先对其进行编译,然后再通过编译得到的对象做正则匹配。这个编译的过程,就可以理解为 JIT(即时编译)。
在深度学习中 JIT 的思想也是随处可见:比如 Keras 框架中的 model.compile
。在模型计算之前,为模型建立静态计算图,这样做有两个好处:一是在计算之前进行优化;二是方便多次复用计算图。
而与此相对,Pytorch 使用的是动态计算图。它没有 model.compile
这个过程。可以说,Pytorch 牺牲了一些高级特性,换取了易用性。
TorchScript
TorchScript 是 PyTorch 的 JIT 实现。它带来了几个明显的好处:
- 模型部署
PyTorch 的 1.0 版本发布的最核心的两个新特性就是 JIT 和 C++ API,这两个特性一起发布不是没有道理的,JIT 是 Python 和 C++ 的桥梁,我们可以使用 Python 训练模型,然后通过 JIT 将模型转为语言无关的模块,从而让 C++ 可以非常方便地调用,从此「使用 Python 训练模型,使用 C++ 将模型部署到生产环境」对 PyTorch 来说成为了一件很容易的事。而因为使用了 C++,我们现在几乎可以把 PyTorch 模型部署到任意平台和设备上:树莓派、iOS、Android 等等… ——引自文章 PyTorch JIT
- 性能提升
既然是为部署生产所提供的特性,那免不了在性能上面做了极大的优化,如果推断的场景对性能要求高,则可以考虑将模型torch.nn.Module
转换为TorchScript Module
,再进行推断。——引自文章 PyTorch JIT
- 模型可视化
TensorFlow 或 Keras 对模型可视化工具(TensorBoard等)非常友好,因为本身就是静态图的编程模型,在模型定义好后整个模型的结构和正向逻辑就已经清楚了;但 PyTorch 本身是不支持的,所以 PyTorch 模型在可视化上一直表现得不好,但 JIT 改善了这一情况。现在可以使用 JIT 的 trace 功能来得到 PyTorch 模型针对某一输入的正向逻辑,通过正向逻辑可以得到模型大致的结构,但如果在 forward 方法中有很多条件控制语句,这依然不是一个好的方法,所以 PyTorch JIT 还提供了 Scripting 的方式。——引自文章 PyTorch JIT
TorchScript Module 的两种生成方式
TorchScript 有 追踪(Tracing)和编码(Scripting) 两种生成方式。
Tracing 是一种比较简单的方式,可以直接将 PyTorch 模型 torch.nn.Module
转换成 TorchScript Module
。顾名思义,“追踪” 就是提供一个 “输入”,在模型中前向传播一遍;通过该输入的流转路径,获得计算图的结构。这种方式对于前向传播逻辑简单的模型来说非常实用,但如果前向传播里面本身夹杂了很多流程控制语句,则可能会有问题,因为同一个输入不可能遍历到所有的逻辑分支。
下面是一个利用 torch.jit.trace
生成 TorchScript 的栗子:
from torchvision.models import resnet18
model = resnet18()
model.eval()
# 通过 Tracing 的方式需要一个输入样例
dummy_input = torch.rand(1, 3, 224, 224)
jit_model = torch.jit.trace(model, dummy_input)
对于前向传播逻辑复杂的模型来说,需要用到 Scripting 这种方式转换为 TorchScript。这种情况下,需要利用 TorchScript Language
编写模型,然后用 torch.jit.script
转换成 TorchScript
TorchScript Language
是 Python 的一个子集。它本身属于 Python 代码,但只支持 Python 的部分特性。官方文档 TORCHSCRIPT LANGUAGE REFERENCE 中列出了所有 TorchScript Language
支持的 Python 特性,没有被提及的均不支持。
序列化
两种方法创建的 TorchScript 都可以进行序列化:
# 进行序列化
torch.jit.save(jit_module, "jit_module.pth")
# 加载序列化后的模型
jit_model = torch.jit.load('jit_model.pth')
序列化后的模型不再与 python 相关,可以被部署到各种平台上。PyTorch 提供了可以用于 TorchScript 模型推理的 C++ API,序列化后的模型可以不依赖 python 进行推理。
auto module = torch::jit::load(PATH);
...
torch::Tensor output = module.forward(inputs).toTensor();
这里有一个 Github 仓库,用简单的模型演示了如何用 Python 训练模型,并转换为 TorchScript,然后用 C++ 加载 TorchScript 做推断的完整工作流。
与 torch.onnx 的关系
ONNX 是广泛使用的一种神经网络中间表示,可以作为各个深度学习框架的转换媒介。
torch.onnx.export
函数可以把 PyTorch 模型转换成 ONNX 模型。而转换的步骤之一,就是利用 Tracing 的方式生成 TorchScript。
参考: