PyTorch Eager mode and Script mode


本文大概总结一下近期对 pytorch 中的 eager 模式还有 script 模式的学习所得。

前言

断断续续接触这两个概念有很长一段时间了,但是始终觉得对这两个 pytorch 的重要特性的概念就是比较模糊,中间还夹杂了一个 JIT trace 的概念,让我一句话归纳总结它们就是:

  • Eager 模式:Python + Python runtime。这种模式是更 Pythonic 的编程模式,可以让用户很方便的使用 python 的语法来使用并调试框架,就像我们刚认识 pytorch 时它的样子,自带 eager 属性。(但是我始终对这个 eager 有点对不上号 T _ T)
  • Script 模式:TorchScript + PyTorch JIT。这种模式会对 eager 模式的模型创建一个中间表示(intermediate representation,IR),这个 IR 经过内部优化的,并且可以使用 PyTorch JIT 编译器去运行模型,不再依赖 python runtime,也可以使用 C++ 加载运行。
Script model

PyTorch 深受人们的喜爱主要是因为它的灵活和易用性(畏难心理,我到现在都还是对 TF 有点排斥),但是在模型部署方面,PyTorch 的表现却不尽人意,性能及可移植性都欠缺。之前使用 PyTorch 的痛点也是从研究到产品跨度比较大,不能直接将模型用来部署,为了解决这个 gap,PyTorch 提出了 TorchScript,想要通过它来实现从研究到产品的框架统一,通过TorchScript得到的模型可以脱离 python 的 runtime 并使你的模型跑的更快。

  • 可移植性:script 模式可以不用再使用 python runtime,因此可以用在多线程推理服务器,移动设备,自动驾驶等 python 很难应用的场景。
  • 性能表现:PyTorch JIT 是可以对 PyTorch 模型做特定优化的 JIT 编译器,其可以利用 runtime 的信息做量化,层融合,稀疏化等 Script 模型优化加速模型。

TorchScript 是一种编程语言,是 Python 的静态类型子集,它有自己的语法规则,我们使用 eager 模式来进行原型验证及训练的过程都是直接使用 python 语法,所以想得到方便部署的 script mode 需要通过torch.jit.trace 或者是 torch.jit.script 去处理模型。

torch.jit.trace

torch.jit.trace() 把训练后得到的 eager 模型以及模型需要的输入数据作为接口输入,然后 tracer 会把数据在 eager 模型里运行一次,并且记录执行的 tensor 操作,记录的结果会保存成一个 TorchScript 模块。

但是它的主要缺点就是不支持控制流,数据结构(list,dict 等)和 python 结构,并且可能部分操作没有正确的被记录在 TorchScript 模块中,但是不会给任何警示信息,不能保证输出的一定是正确的 TorchScript 模块。

torch.jit.script

torch.jit.script 用作装饰器可以将你的代码转化成写成 TorchScript 语言,它转化出来的模型更冗长(携带更多的信息),但是更通用,经过些许修改就可以支持大部分的 PyTorch 模型。 也可以用作接口,直接将 eager 模型送入torch.jit.script(),无需再送入数据。它支持控制流以及一些 Python 的数据结构。但是它会省略常量节点,并需要类型转换,如果没有类型提供则默认是 Tensor 类型。

因为 torch.jit.trace() 不支持控制流,torch.jit.script() 不会记录常量节点,当我们需要记录常量节点又需要支持控制流时就可以把二者结合在一起,下面直接贴出官方示例:

class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)

可以得到:

def forward(self,
    xs: Tensor) -> Tuple[Tensor, Tensor]:
  h = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
  y = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
  y0 = y
  h0 = h
  for i in range(torch.size(xs, 0)):
    _0 = (self.cell).forward(torch.select(xs, 0, i), h0, )
    y1, h1, = _0
    y0, h0 = y1, h1
  return (y0, h0)

或者:

class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)

traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)

可以得到:

def forward(self,
    argument_1: Tensor) -> Tensor:
  _0, h, = (self.loop).forward(argument_1, )
  return torch.relu(h)

参考文章:

PyTorch 是一个开源的深度学习框架,由 Facebook 的人工智能研究团队(FAIR)开发并维护。它结合了两个重要的功能:即 Torch 提供的强大 GPU 加速能力与 Python 编程语言的高度灵活性。以下是关于 PyTorch 的详细介绍: ### 核心特点 1. **动态计算图**: - 与其他静态声明式框架(如 TensorFlow 早期版本)不同的是,PyTorch 使用即时模式(Eager Mode),这意味着可以在运行时直接评估操作而无需先构建整个计算图。这使得调试更容易,并且更贴近常规编程习惯。 2. **易于使用的 API**: - PyTorch 设计简洁直观,拥有非常接近 NumPy 的语法风格,但同时支持自动求导机制 Autograd 来简化梯度计算过程。对于研究人员来说尤其友好,因为它允许快速原型设计和实验迭代。 3. **强大的社区支持**: - 自发布以来积累了庞大的开发者群体和技术资源库,包括官方教程、第三方扩展包以及活跃的问题解答论坛等。丰富的资料极大地促进了新用户的入门速度和技术交流氛围。 4. **跨平台兼容性好**: - 支持多种操作系统环境下的部署,无论是 Linux、Windows 还是 MacOS 都能顺利安装使用;此外还具备良好的硬件加速选项,能够充分利用 NVIDIA CUDA 技术发挥显卡性能优势。 5. **分布式训练和推理优化**: - 内置了高效的分布式训练工具,方便用户针对大规模数据集或复杂网络结构实施多机多GPU协同工作流程。同时也提供了一些实用组件帮助提升生产环境中模型推断效率。 ### 应用场景 由于上述特性,PyTorch 被广泛应用于各个领域内的机器学习任务中,特别是计算机视觉、自然语言处理等方面的研究项目里扮演着重要角色。例如 ResNet、BERT 等著名模型均是在 PyTorch 上实现和发展起来的经典案例之一。 总之,如果你正在寻找一款既强大又易学易用的深度学习利器,那么 PyTorch 绝对值得考虑!
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值