前提
模型训练好后自然想要将里面所有层涉及的权重保存下来,这样子我们的模型就能部署在任意有pytorch环境下了。
Torch.save/load
先介绍一下纯py接口的保存方式。
class my_model(nn.Modules):
def __init__(self):
super(my_module,self).__init__()
self.relu = nn.Relu()
def forward(x, self):
return self.relu(x)
......
......
model = my_model()
torch.save(model.state_dict(),"./model_name.pth")
torch.load(model,"./model_name.pth)
这样就会在当前目录保存一份.pth文件了(里面只保存了这个模型的所有权重即.parameters())。
下面的load就是在其他脚本中使用这个模型预训练好的权重
Torch.save 官网有详细介绍
Torch.jit
这个方式的保存更加高级(保存为TorchScript)可以与Torch c++接口通用。
这样带来的好处就是保存下来的模型为编译过后的运行时不需要python解释器,运行速度会更快。并且这种方式可以连带模型的定义一起保存,无需import model。
一般有两种保存方式
- torch.jit.trace
这种方式为追踪一个函数的执行流,使用时需要提供一个测试输入。
官网有样例。
Torch.jit.trace
需要注意的是这个接口只追踪测试输入走过的函数执行流(如果模型中有多条分支的话只会保存测试输入走过的分支!!!!!),所以对于一些多分支的模型不要采用这种方式,采用下面的Torch.jit.script。比如model.eval()和model.train()可以控制模型内BN层和dropout的权重是否固定,如果采用这种方式只能保留其中之一状态(固定或不固定)。 - torch.jit.script
使用这种方式可以将一个模型完整的保存下来,和上面的trace正好相对。如果模型中的分支很多,并且在运行时会改变的话一定要用这种形式保存。
Torch.jit.trace
这里简单写一下如何使用:(我这个测试模型里面只有一个简单的relu所以随便输入一个Tensor就行了)
store = torch.jit.trace(model,torch.randn(1,2,3,dtype=torch.float32))
store = torch.jit.script(model)
torch.jit.save(store,"./model_name.pth")
torch.jit.load("./model_name.pth")