[PyTorch] jit.script 与 jit.trace

pytorch 导出模型有两种方式,一种是 torch.jit.script,一种是 torch.jit.trace

jit.script

如果网络中用到了 索引,if 条件判断等,很有可能会失败。

注意事项

1. 使用索引寻找模块

如果神经网络中使用了索引 self.ppms[k]

for k in range(len(self.ppms)):
    xls.append(F.interpolate(self.ppms[k](xs_1), xs_1.size()[2:], mode='bilinear', align_corners=True))

则需改为

for k, m in enumerate(self.ppms):
    xls.append(F.interpolate(m(xs_1), xs_1.size()[2:], mode='bilinear', align_corners=True))

如果是这样

for i in range(3):
	y = self.convs[i](self.pools[i](x))

则需改为

for i, (p, m) in enumerate(zip(self.pools, self.convs)):
	y = m(p(x))

2. 模块的初始化需要if条件

这种情况下,会报错说找不到 conv_sum_c 模块。

def Net(nn.module):
	def __init__(self, need_fuse):
        if self.need_fuse:
            self.conv_sum_c = nn.Conv2d(k_out, k_out, 3, 1, 1, bias=False)
    def forward(self, x):
    	resl = x
        if self.need_fuse:
            resl = self.conv_sum_c(torch.add(torch.add(resl, x2), x3))
        return resl

用法

scripted_module = torch.jit.script(net)
scripted_module.save("net.pt")

jit.trace

需要提供一个输入。可以用于比较复杂的网络。

input = torch.ones(1, 3, 224, 224)
traced_module = torch.jit.trace(net, input)
traced_module.save("net.pt")
  • 5
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值