1.参考原文档:Introduction to Torchscript
TorchScript是PyTorch模型(nn.Module的子类)的中间表示,可以C++等高性能环境中运行。
2. 把torch模型转成torchscript脚本保存,需要调用torch.jit.trace,不能直接保存模型,需要提供一个输入样例,跟踪记录对这个样例的操作,把python的代码转成torchscript的脚本。
x = torch.ones(1,3,640,640)
trace_model = torch.jit.trace(model,x)
trace_model.save('model.pt')
如果代码中有if条件控制,尽量避免使用torch.jit.trace来转换代码,因为它不能处理变化的条件,如果非要用trace的话,可以把if条件控制改成别的形式。例如:self.mode == “UCBA”
if self.mode == "UCBA":
return self.conv(self.up(x))
elif self.mode == "DeconvBN":
return F.relu(self.bn(self.dconv(x)))
elif self.mode == "DeCBA":
return self.conv(self.dconv(x))
可以改成:
# if self.mode == "UCBA":
# return self.conv(self.up(x))
# elif self.mode == "DeconvBN":
# return F.relu(self.bn(self.dconv(x)))
# elif self.mode == "DeCBA":
# return self.conv(self.dconv(x))
return self.conv(self.up(x))
- torch.jit.trace记录的只是当前代码走的路径,同一个代码,这次走if分支,下次走else分支,那么torch.jit.trace记录的就会不同,在这种情况下,我们可以用torch.jit.script。
x = torch.ones(1,3,640,640)
script_model = torch.jit.script(model,x)
script_model.save('model.pt')
- torchscript模型加载
loaded = torch.jit.load('model.pt')
5.Scripting和Tracing的混合使用
有些情况需要使用跟踪而不是脚本(例如,一个模块有许多基于常量Python值的架构决策,我们不希望在TorchScript中出现)。在这种情况下,脚本可以由跟踪组成:torch.jit.script将内联被跟踪模块的代码,而跟踪将内联脚本模块的代码。
class MyRNNLoop(torch.nn.Module):
def __init__(self):
super(MyRNNLoop, self).__init__()
self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))
def forward(self, xs):
h, y = torch.zeros(3, 4), torch.zeros(3, 4)
for i in range(xs.size(0)):
y, h = self.cell(xs[i], h)
return y, h
rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)
class WrapRNN(torch.nn.Module):
def __init__(self):
super(WrapRNN, self).__init__()
self.loop = torch.jit.script(MyRNNLoop())
def forward(self, xs):
y, h = self.loop(xs)
return torch.relu(y)
traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)