阅读官网文档发现pytorch除了使用torch.save与torch.load保存加载数据/模型外还可以使用torch.jit.save(),torch.jit.load()。
后者的好处在于
- 其他py脚本在使用模型时不需要将原模型所在py文档也引入
- 可以跨平台使用
官网代码
scripted_module = torch.jit.script(MyModule())
torch.jit.save(scripted_module, 'mymodule.pt')
torch.jit.load('mymodule.pt')
以上代码即可,注意torch.jit.save保存的时torchscript模型,所以要先进行转换。否则报错如下:
torch.nn.modules.module.ModuleAttributeError: 'Model