目录
模型保存/加载所有模型参数
第一种方法
第二种方法(不推荐)
模型保存/加载部分模型参数
Finetune
目前Pytorch模型
模型保存/加载所有模型参数
第一种方法
#保存模型到checkpoint.pth.tar
torch.save(model.module.state_dict(), ‘checkpoint.pth.tar’)
对应的加载模型方法为(这种方法需要先反序列化模型获取参数字典,因此必须先load模型,再load_state_dict):
mymodel.load_state_dict(torch.load(‘checkpoint.pth.tar’))
例子:说明如何在inference AND/OR resume train使用
#保存模型的状态,可以设置一些参数,后续可以使用
state = {'epoch': epoch + 1,#保存的当前轮数
'state_dict': mymodel.state_dict(),#训练好的参数
'optimizer': optimizer.state_dict(),#优化器参数,为了后续的resume
'best_pred': best_pred#当前最好的精度
,....,...}
#保存模型到checkpoint.pth.tar
torch.save(state, ‘checkpoint.pth.tar’)
#如果是best,则复制过去
if is_best:
shutil.copyfile(filename, directory