pytorch中的while for 循环 导出onnx的问题

文章讲述了如何将一个包含循环的PyTorch模块转换为可执行的ScriptModule,以便在执行时不再受限于固定参数的静态图。作者演示了如何使用`torch.jit.trace`进行转换,并进一步导出为ONNX模型以提高性能和兼容性。
摘要由CSDN通过智能技术生成

问题:

for执行次数不跟据输入而改变。

解决方案:

torch.jit.script

例如:

class LoopAdd(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x):
        h = x
        for i in range(x.size(0)):
            h = h + 1
        return h
input_1 = torch.ones(3, 16)
model = LoopAdd()
traced_model = torch.jit.trace(model, (input_1, ))
print(traced_model.graph)
graph(%self : __torch__.LoopAdd,
      %x : Float(3, 16, strides=[16, 1], requires_grad=0, device=cpu)):
  %7 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() # /home/mark.yj/GPT-SoVITS/b.py:8:0
  %8 : int = prim::Constant[value=1]() # /home/mark.yj/GPT-SoVITS/b.py:8:0
  %h.1 : Float(3, 16, strides=[16, 1], requires_grad=0, device=cpu) = aten::add(%x, %7, %8) # /home/mark.yj/GPT-SoVITS/b.py:8:0
  %10 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() # /home/mark.yj/GPT-SoVITS/b.py:8:0
  %11 : int = prim::Constant[value=1]() # /home/mark.yj/GPT-SoVITS/b.py:8:0
  %h : Float(3, 16, strides=[16, 1], requires_grad=0, device=cpu) = aten::add(%h.1, %10, %11) # /home/mark.yj/GPT-SoVITS/b.py:8:0
  %13 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() # /home/mark.yj/GPT-SoVITS/b.py:8:0
  %14 : int = prim::Constant[value=1]() # /home/mark.yj/GPT-SoVITS/b.py:8:0
  %15 : Float(3, 16, strides=[16, 1], requires_grad=0, device=cpu) = aten::add(%h, %13, %14) # /home/mark.yj/GPT-SoVITS/b.py:8:0
  return (%15)

改造为ScriptModule:

class LoopAdd(torch.jit.ScriptModule):
    def __init__(self):
        super().__init__()
    @torch.jit.script_method
    def forward(self, x):
        h = x
        for i in range(x.size(0)):
            h = h + 1
        return h
input_1 = torch.ones(3, 16)
model = LoopAdd()
traced_model = torch.jit.trace(model, (input_1, ))
print(traced_model.graph)
graph(%self : __torch__.LoopAdd,
      %x.1 : Tensor):
  %8 : bool = prim::Constant[value=1]() # /home/mark.yj/GPT-SoVITS/b.py:18:8
  %4 : int = prim::Constant[value=0]() # /home/mark.yj/GPT-SoVITS/b.py:18:30
  %11 : int = prim::Constant[value=1]() # /home/mark.yj/GPT-SoVITS/b.py:19:20
  %5 : int = aten::size(%x.1, %4) # /home/mark.yj/GPT-SoVITS/b.py:18:23
  %h : Tensor = prim::Loop(%5, %8, %x.1) # /home/mark.yj/GPT-SoVITS/b.py:18:8
    block0(%i : int, %h.9 : Tensor):
      %h.3 : Tensor = aten::add(%h.9, %11, %11) # /home/mark.yj/GPT-SoVITS/b.py:19:16
      -> (%8, %h.3)
  return (%h)

可以看到 prim::Loop ,说明不再是固定参数的静态图了。

转ScriptModule

将模型转换为 torch.jit.ScriptModule
使用 torch.jit.trace_module() 跟踪模型并输入样本
使用 torch.onnx.export() 导出 ONNX 模型

  • 4
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值