pytorch保存模型有两种方法:
- 保存整个模型 (结构+参数)
- 只保存参数(官方推荐)
两者都是用torch.save(obj, dir)
实现,这个函数的作用是将对象保存到磁盘中,它的内部是使用Python的pickle实现。
两种方法的区别其实就是obj参数的不同:前者的obj是整个model对象,后者的obj是从model里获取的存储了model参数的词典,推荐用第二种,虽然麻烦了一丁点,但是比较灵活,有利于实现预训练、参数迁移等操作。
保存整个模型
这种方法很简单,保存和加载就两行代码,和Python pickle包的用法是一样的,把model当作一个对象直接保存加载就行。
# 保存
model = Mymodel()
torch.save(model, path)
# 加载
model = torch.load(path)
Note:PyTorch约定使用.pt或.pth后缀命名保存文件。
保存参数
重点介绍一下这种方法,一般训完一个模型之后我们不会单独只保存一个模型的参数,为了方便后续操作,比如恢复训、参数迁移等,我们会保存当前转态的一个快照,具体信息可以根据自己的需要,下面列出几个方面:
- 模型参数
- 优化器参数
- loss
- epoch
- args
把这些信息用字典包装起来,然后保存即可。
这种方式保存的模型只是它的参数,所以我们在加载时需要先创建好模型,然后再把参数加载进去,如下:
# 获得保存信息
save_data = {
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
'epoch': epoch,
'args': args
...
}
# 保存
torch.save(save_data , path)
load_data = torch.load(path)
model = Mymodel()
optimizer = Myoptimizer()
# 加载参数
model.load_state_dict(load_data ['model_state_dict'])
optimizer.load_state_dict(load_data ['optimizer_state_dict'])
...
Note:PyTorch约定使用.pt或.pth后缀命名保存文件。