torch模型转换成torchscript格式

1.参考原文档:Introduction to Torchscript
TorchScript是PyTorch模型(nn.Module的子类)的中间表示,可以C++等高性能环境中运行。
2. 把torch模型转成torchscript脚本保存,需要调用torch.jit.trace,不能直接保存模型,需要提供一个输入样例,跟踪记录对这个样例的操作,把python的代码转成torchscript的脚本。

x = torch.ones(1,3,640,640)
trace_model = torch.jit.trace(model,x)
trace_model.save('model.pt')

如果代码中有if条件控制,尽量避免使用torch.jit.trace来转换代码,因为它不能处理变化的条件,如果非要用trace的话,可以把if条件控制改成别的形式。例如:self.mode == “UCBA”

  if self.mode == "UCBA":
       return self.conv(self.up(x))
  elif self.mode == "DeconvBN":
       return F.relu(self.bn(self.dconv(x)))
  elif self.mode == "DeCBA":
       return self.conv(self.dconv(x))

可以改成:

# if self.mode == "UCBA":
#     return self.conv(self.up(x))
# elif self.mode == "DeconvBN":
#     return F.relu(self.bn(self.dconv(x)))
# elif self.mode == "DeCBA":
#     return self.conv(self.dconv(x))
  return self.conv(self.up(x))
  1. torch.jit.trace记录的只是当前代码走的路径,同一个代码,这次走if分支,下次走else分支,那么torch.jit.trace记录的就会不同,在这种情况下,我们可以用torch.jit.script。
x = torch.ones(1,3,640,640)
script_model = torch.jit.script(model,x)
script_model.save('model.pt')
  1. torchscript模型加载
loaded = torch.jit.load('model.pt')

5.Scripting和Tracing的混合使用
有些情况需要使用跟踪而不是脚本(例如,一个模块有许多基于常量Python值的架构决策,我们不希望在TorchScript中出现)。在这种情况下,脚本可以由跟踪组成:torch.jit.script将内联被跟踪模块的代码,而跟踪将内联脚本模块的代码。

class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)
class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)

traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值