torchscript

torchscript

原文: 【Pytorch】(十三)模型部署: TorchScript

Pytorch动态图的缺点

与Tensorflow使用静态计算图不同,PyTorch 使用的是动态计算图:

在这里插入图片描述

动态图允许在运行时渐进地构建计算图(如上图所示,计算图表示计算过程),使得模型设计更加灵活。开发者可以使用 Python 的控制流结构(如循环、条件语句等)来动态地定义模型的结构,从而更容易实现复杂的模型逻辑。

这种计算方式更直观,更pythonic。开发者可以更容易地理解和调试模型各个模块,快速地修改、迭代模型。

然而,与静态图相比,动态图的执行效率可能会较低。因为动态图难以进行一些计算图的优化,如运算符融合、图优化等。而且,动态图依赖于Python 环境。这些因素使得动态图不适合在低延迟要求较高的生产环境下部署。

因此,在部署Pytorch训练后的模型时,需要将动态图转换为静态图,这就要用到TorchScript

torchscript

TorchScript是PyTorch模型的一种静态图表示形式,支持模型的部署优化、跨平台部署以及与其他深度学习框架的集成

  • 模型的部署优化:TorchScript 可以帮助优化 PyTorch 模型以提高性能和效率。通过将模型转换为静态图形式,TorchScript 可以应用各种优化技术,如运算符融合、图优化等,从而加速模型执行并降低内存消耗。
  • 跨平台部署:将模型转换为 TorchScript 格式可以实现跨平台部署,模型可以在没有 Python 环境的情况下运行。这对于在生产环境中部署模型到服务器、移动设备或边缘设备上非常有用。
  • 与其他框架集成:通过将 PyTorch 模型转换为 TorchScript 格式,可以更方便地与其他深度学习框架进行交互。例如,可以将TorchScript 进一步转换为 ONNX 格式,从而与 TensorFlow 等其他框架进行集成和交互操作。

Pytorch模型转换为TorchScript

torch.jit.tracetorch.jit.script 是 PyTorch 中用于模型转换为 TorchScript 格式的工具,但它们有不同的作用和使用场景。

torch.jit.trace

通过torch.jit.trace 将 没有控制流的MyCell 模块转化为TorchScript:

import torch 
import torch.nn as nn

torch.manual_seed(191009)

class MyCell(nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.sigmoid(self.linear(x)) * h
        return new_h, new_h
    
my_cell = MyCell()
x,h = torch.rand(3,4),torch.rand(3,4)
traced_cell = torch.jit.trace(my_cell,(x,h))
print(traced_cell)
MyCell(
  original_name=MyCell
  (linear): Linear(original_name=Linear)
)

torch.jit.trace调用了my_cell,记录了模块计算时发生的操作,并创建了一个torch.jit.ScriptModule的实例(TracedModule是其实例)traced_cell。traced_cell 记录了my_cell的计算图。我们可以使用.graph属性来查看:

print(traced_cell.graph)
graph(%self.1 : __torch__.MyCell,
      %x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %linear : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)
  %19 : Tensor = prim::CallMethod[name="forward"](%linear, %x)
  %11 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::sigmoid(%19) # C:\Users\10766\AppData\Local\Temp\ipykernel_30996\4271899992.py:12:0
  %12 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::mul(%11, %h) # C:\Users\10766\AppData\Local\Temp\ipykernel_30996\4271899992.py:12:0
  %13 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%12, %12)
  return (%13)

然而,图中包含的大多数信息对我们没有用处。我们可以使用.code属性对其进行Python语法解释

print(traced_cell.code)
def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  linear = self.linear
  _0 = torch.mul(torch.sigmoid((linear).forward(x, )), h)
  return (_0, _0)

调用traced_cell会产生与Python模块实例my_cell() 相同的结果:

print(my_cell(x, h))
print(traced_cell(x, h))

(tensor([[0.2387, 0.1769, 0.0538, 0.0618],
        [0.3556, 0.2774, 0.3545, 0.3884],
        [0.2733, 0.1302, 0.2116, 0.2834]], grad_fn=<MulBackward0>), tensor([[0.2387, 0.1769, 0.0538, 0.0618],
        [0.3556, 0.2774, 0.3545, 0.3884],
        [0.2733, 0.1302, 0.2116, 0.2834]], grad_fn=<MulBackward0>))
(tensor([[0.2387, 0.1769, 0.0538, 0.0618],
        [0.3556, 0.2774, 0.3545, 0.3884],
        [0.2733, 0.1302, 0.2116, 0.2834]], grad_fn=<MulBackward0>), tensor([[0.2387, 0.1769, 0.0538, 0.0618],
        [0.3556, 0.2774, 0.3545, 0.3884],
        [0.2733, 0.1302, 0.2116, 0.2834]], grad_fn=<MulBackward0>))

torch.jit.script

我们先尝试通过torch.jit.trace 将 带有控制流的MyCell 模块转化为TorchScript:

class MyDecisionGate(nn.Module):
    def forward(self,x):
        if x.sum() > 0:
            return x 
        else: 
            return -x
        

class MyCell_1(nn.Module):
    def __init__(self,dg):
        super(MyCell_1,self).__init__()
        self.dg = dg
        self.linear = nn.Linear(4,4)

    def forward(self,x,h):
        new_h = torch.tanh(self.dg(self.linear(x)+h))
        return new_h,new_h
    

my_cell = MyCell_1(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell,(x,h))

print(traced_cell.dg.code)
print(traced_cell.code)
def forward(self,
    x: Tensor) -> NoneType:
  return None

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  x0 = torch.add((linear).forward(x, ), h)
  _0 = (dg).forward(x0, )
  _1 = torch.tanh(x0)
  return (_1, _1)



C:\Users\10766\AppData\Local\Temp\ipykernel_30996\2585701402.py:3: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if x.sum() > 0:

可以看到,if-else分支并没有被表示出来。为什么?
trace记录代码运行发生的操作,并构造一个ScriptModule。控制流中只有一种情况被记录了下来,其他情况都被忽略了。

这就需要用到torch.jit.script

scripted_gate = torch.jit.script(MyDecisionGate())

my_cell = MyCell_1(scripted_gate)
scripted_ceil = torch.jit.script(my_cell)

print(scripted_gate.code)
print(scripted_ceil.code)
def forward(self,
    x: Tensor) -> Tensor:
  if bool(torch.gt(torch.sum(x), 0)):
    _0 = x
  else:
    _0 = torch.neg(x)
  return _0

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = (dg).forward(torch.add((linear).forward(x, ), h), )
  new_h = torch.tanh(_0)
  return (new_h, new_h)

可以考到,控制流也被记录了下来。
现在让我们尝试运行该程序:

x,h = torch.rand(3,4),torch.rand(3,4)
print(scripted_ceil(x,h))
(tensor([[ 0.5790,  0.3942, -0.1636, -0.0422],
        [ 0.4143,  0.3593,  0.6575, -0.0613],
        [ 0.6404,  0.3927,  0.7470, -0.0614]], grad_fn=<TanhBackward0>), tensor([[ 0.5790,  0.3942, -0.1636, -0.0422],
        [ 0.4143,  0.3593,  0.6575, -0.0613],
        [ 0.6404,  0.3927,  0.7470, -0.0614]], grad_fn=<TanhBackward0>))

trace和script的区别总结

  • torch.jit.trace: torch.jit.trace 用于将一个具体的输入示例追踪(trace)模型的一次计算过程,从而生成一个 TorchScript 模型。对于动态控制流(如条件语句),它只会记录每个分支中的一种情况。因此,它不适用于无固定形状输入、具有动态控制流的模型。
  • torch.jit.script: torch.jit.script 用于将整个 PyTorch 模型转换为 TorchScript 模型,包括模型的所有逻辑和控制流。script适用于无固定形状输入、具有动态控制流的模型 。但是,它可能会把保存一些多余的代码, 产生额外的性能开销

因此,可以将两者混合使用,扬长避短。

trace 和script 混合使用

torch.jit.tracetorch.jit.script 可以混合使用: 复杂模型中静态部分用torch.jit.trace进行转换, 动态部分用torch.jit.script 进行转换,以发挥各自的优势。以下是两个可能的情况:

  • torch.jit.script内联traced模块的代码,
class MyRNNLoop(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super(MyRNNLoop,self).__init__(*args, **kwargs)
        self.cell = torch.jit.trace(MyCell_1(scripted_gate),(x,h))

    def forward(self,xs):
        h,y = torch.zeros(3,4),torch.zeros(3,4)
        for i in range(xs.size(0)):
            y,h = self.cell(xs[i],h)

        return y,h
    

rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)
def forward(self,
    xs: Tensor) -> Tuple[Tensor, Tensor]:
  h = torch.zeros([3, 4])
  y = torch.zeros([3, 4])
  y0 = y
  h0 = h
  for i in range(torch.size(xs, 0)):
    cell = self.cell
    _0 = (cell).forward(torch.select(xs, 0, i), h0, )
    y1, h1, = _0
    y0, h0 = y1, h1
  return (y0, h0)
  • torch.jit.trace内联scripted模块的代码,
class WrapLoop(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super(WrapLoop,self).__init__(*args, **kwargs)
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self,xs):
        y,h = self.loop(xs)
        return torch.relu(y)
    

traced = torch.jit.trace(WrapLoop(),(torch.rand(10,3,4)))
print(traced)
WrapLoop(
  original_name=WrapLoop
  (loop): RecursiveScriptModule(
    original_name=MyRNNLoop
    (cell): MyCell_1(
      original_name=MyCell_1
      (dg): RecursiveScriptModule(original_name=MyDecisionGate)
      (linear): Linear(original_name=Linear)
    )
  )
)

保存和加载模型

  • traced.save : 保存TorchScript
  • torch.jit.load : 加载TorchScript
traced.save('wrapped_rnn.pt')

loaded = torch.jit.load('wrapped_rnn.pt')
print(loaded)
print(loaded.code)
RecursiveScriptModule(
  original_name=WrapLoop
  (loop): RecursiveScriptModule(
    original_name=MyRNNLoop
    (cell): RecursiveScriptModule(
      original_name=MyCell_1
      (dg): RecursiveScriptModule(original_name=MyDecisionGate)
      (linear): RecursiveScriptModule(original_name=Linear)
    )
  )
)
def forward(self,
    xs: Tensor) -> Tensor:
  loop = self.loop
  _0, y, = (loop).forward(xs, )
  return torch.relu(y)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值