神经网络的训练往往需要一定的时间,如果训练过程中需要临时中断,其训练参数的保存与重新加载显得至关重要。
模型参数的保存:
# 模型参数保存,model是网络模型,optimizer是优化器
state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
torch.save(state, '文件名.pth')
模型参数的加载:
checkpoint=torch.load('文件名.pth')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
print("加载成功!")
或者使用以下方式:
import os
# 保存模型参数
torch.save(model.state_dict(), 'model_para/model_net.pt')
torch.save(optimizer.state_dict(), 'model_para/optimizer.pt')
print("模型参数保存成功!")
# 加载模型参数
if os.path.exists("./'model_para/model_net.pt'"):
model.load_state_dict(torch.load("./model_para/model_net.pt"))
optimizer.load_state_dict(torch.load("./model_para/optimizer.pt"))
print("模型参数加载成功!")