pytorch的pt模型保存与加载
- 仅保存和加载模型参数
# 模型保存
checkpoint={'epoch':epoch,
'best_loss':best_loss,
'model':model.state_dict()
'optimizer':optimizer.state_dict()
}
torch.save(checkpoint, PATH='model.pth') #只保存模型权重参数,不保存模型结构
# 模型调用
the_model = TheModelClass(*args, **kwargs) #需要重新模型结构
the_model.load_state_dict(torch.load('model.pth')) #根据模型结构,调用存储的模型参数
- 保存并加载整个模型
# 模型保存
torch.save(model, 'model.pth') #保存整个model的状态
# 模型调用
the_model = torch.load('model.pth') #这里已经不需要重构模型结构了,直接load就可以
总结:
第一种方式需自定义网络,并且其中的参数名称与结构要与保存的模型中的一致(可以是部分网络,比如只使用某个网络的前几层),相对灵活,便于对网络进行修改。
第二种方式无需自定义网络,保存时已把网络结构保存,不能调整网络结构。