Pytorch 即时编译与 TorchScript

这篇文章主要讨论 PyTorch JIT (Just In Time Compilation) 与TorchScript。JIT 和 TorchScript 之间有什么联系?它们如何为 Pytorch 引入静态特性,并方便模型部署?接下来,我们试着回答这些问题。


JIT

JIT 是一种概念,译为“即时编译”,是一种程序优化的方法。我们引用 PyTorch JIT 这篇文章中的栗子:

在 Python 中使用正则表达式:

prog = re.compile(pattern)
result = prog.match(string)

或者

result = re.match(pattern, string)

两种写法从结果上来说是等价的。但注意,第一种写法中,会先对正则表达式进行编译(compile),然后再进行使用。

Python 官方建议,如果多次使用到某一个正则表达式,建议先对其进行编译,然后再通过编译得到的对象做正则匹配。这个编译的过程,就可以理解为 JIT(即时编译)。

在深度学习中 JIT 的思想也是随处可见:比如 Keras 框架中的 model.compile。在模型计算之前,为模型建立静态计算图,这样做有两个好处:一是在计算之前进行优化;二是方便多次复用计算图。

而与此相对,Pytorch 使用的是动态计算图。它没有 model.compile 这个过程。可以说,Pytorch 牺牲了一些高级特性,换取了易用性。


TorchScript

TorchScript 是 PyTorch 的 JIT 实现。它带来了几个明显的好处:

  1. 模型部署
    PyTorch 的 1.0 版本发布的最核心的两个新特性就是 JIT 和 C++ API,这两个特性一起发布不是没有道理的,JIT 是 Python 和 C++ 的桥梁,我们可以使用 Python 训练模型,然后通过 JIT 将模型转为语言无关的模块,从而让 C++ 可以非常方便地调用,从此「使用 Python 训练模型,使用 C++ 将模型部署到生产环境」对 PyTorch 来说成为了一件很容易的事。而因为使用了 C++,我们现在几乎可以把 PyTorch 模型部署到任意平台和设备上:树莓派、iOS、Android 等等… ——引自文章 PyTorch JIT
  1. 性能提升
    既然是为部署生产所提供的特性,那免不了在性能上面做了极大的优化,如果推断的场景对性能要求高,则可以考虑将模型 torch.nn.Module 转换为 TorchScript Module,再进行推断。——引自文章 PyTorch JIT
  1. 模型可视化
    TensorFlow 或 Keras 对模型可视化工具(TensorBoard等)非常友好,因为本身就是静态图的编程模型,在模型定义好后整个模型的结构和正向逻辑就已经清楚了;但 PyTorch 本身是不支持的,所以 PyTorch 模型在可视化上一直表现得不好,但 JIT 改善了这一情况。现在可以使用 JIT 的 trace 功能来得到 PyTorch 模型针对某一输入的正向逻辑,通过正向逻辑可以得到模型大致的结构,但如果在 forward 方法中有很多条件控制语句,这依然不是一个好的方法,所以 PyTorch JIT 还提供了 Scripting 的方式。——引自文章 PyTorch JIT

TorchScript Module 的两种生成方式

TorchScript 有 追踪(Tracing)和编码(Scripting) 两种生成方式。

Tracing 是一种比较简单的方式,可以直接将 PyTorch 模型 torch.nn.Module 转换成 TorchScript Module。顾名思义,“追踪” 就是提供一个 “输入”,在模型中前向传播一遍;通过该输入的流转路径,获得计算图的结构。这种方式对于前向传播逻辑简单的模型来说非常实用,但如果前向传播里面本身夹杂了很多流程控制语句,则可能会有问题,因为同一个输入不可能遍历到所有的逻辑分支。

下面是一个利用 torch.jit.trace 生成 TorchScript 的栗子:

from torchvision.models import resnet18 
 
model = resnet18() 
model.eval() 

# 通过 Tracing 的方式需要一个输入样例 
dummy_input = torch.rand(1, 3, 224, 224) 

jit_model = torch.jit.trace(model, dummy_input) 

对于前向传播逻辑复杂的模型来说,需要用到 Scripting 这种方式转换为 TorchScript。这种情况下,需要利用 TorchScript Language 编写模型,然后用 torch.jit.script 转换成 TorchScript

TorchScript Language 是 Python 的一个子集。它本身属于 Python 代码,但只支持 Python 的部分特性。官方文档 TORCHSCRIPT LANGUAGE REFERENCE 中列出了所有 TorchScript Language 支持的 Python 特性,没有被提及的均不支持。


序列化

两种方法创建的 TorchScript 都可以进行序列化:

# 进行序列化
torch.jit.save(jit_module, "jit_module.pth")

# 加载序列化后的模型 
jit_model = torch.jit.load('jit_model.pth') 

序列化后的模型不再与 python 相关,可以被部署到各种平台上。PyTorch 提供了可以用于 TorchScript 模型推理的 C++ API,序列化后的模型可以不依赖 python 进行推理。

auto module = torch::jit::load(PATH);
...
torch::Tensor output = module.forward(inputs).toTensor();

这里有一个 Github 仓库,用简单的模型演示了如何用 Python 训练模型,并转换为 TorchScript,然后用 C++ 加载 TorchScript 做推断的完整工作流。


与 torch.onnx 的关系

ONNX 是广泛使用的一种神经网络中间表示,可以作为各个深度学习框架的转换媒介。

torch.onnx.export 函数可以把 PyTorch 模型转换成 ONNX 模型。而转换的步骤之一,就是利用 Tracing 的方式生成 TorchScript。


参考:

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值