pytorch 导出模型有两种方式,一种是 torch.jit.script,一种是 torch.jit.trace
jit.script
如果网络中用到了 索引,if 条件判断等,很有可能会失败。
注意事项
1. 使用索引寻找模块
如果神经网络中使用了索引 self.ppms[k]
for k in range(len(self.ppms)):
xls.append(F.interpolate(self.ppms[k](xs_1), xs_1.size()[2:], mode='bilinear', align_corners=True))
则需改为
for k, m in enumerate(self.ppms):
xls.append(F.interpolate(m(xs_1), xs_1.size()[2:], mode='bilinear', align_corners=True))
如果是这样
for i in range(3):
y = self.convs[i](self.pools[i](x))
则需改为
for i, (p, m) in enumerate(zip(self.pools, self.convs)):
y = m(p(x))
2. 模块的初始化需要if条件
这种情况下,会报错说找不到 conv_sum_c 模块。
def Net(nn.module):
def __init__(self, need_fuse):
if self.need_fuse:
self.conv_sum_c = nn.Conv2d(k_out, k_out, 3, 1, 1, bias=False)
def forward(self, x):
resl = x
if self.need_fuse:
resl = self.conv_sum_c(torch.add(torch.add(resl, x2), x3))
return resl
用法
scripted_module = torch.jit.script(net)
scripted_module.save("net.pt")
jit.trace
需要提供一个输入。可以用于比较复杂的网络。
input = torch.ones(1, 3, 224, 224)
traced_module = torch.jit.trace(net, input)
traced_module.save("net.pt")