模型的保存与加载
PyTorch中的保存(序列化,从内存到硬盘)与反序列化(加载,从硬盘到内存)
- torch.save主要参数: obj:对象 、f:输出路径
- torch.load 主要参数 :f:文件路径 、map_location:指定存放位置、 cpu or gpu
一、常见的模型保存的两种方法:
1、保存整个Module
torch.save(net, path)
2、保存模型参数
state_dict = net.state_dict()
torch.save(state_dict , path)
二、训练过程中自定义保存内容 与 断点恢复训练
#加载恢复
if RESUME:#是否恢复
path_checkpoint = "./model_parameter/test/ckpt_best_50.pth" # 断点模型文件路径
checkpoint = torch.load(path_checkpoint) # 加载断点
model.load_state_dict(checkpoint['net']) # 加载模型可学习参数
start_epoch = checkpoint['epoch'] # 设置开始的epoch
optimizer.load_state_dict(checkpoint['optimizer']) # 加载优化器参数
lr_schedule.load_state_dict(checkpoint['lr_schedule'])#加载lr_scheduler
#保存
for epoch in range(start_epoch + 1, 80):
optimizer.zero_grad()
optimizer.step()
lr_schedule.step()
if epoch %20 == 19:#每隔20个epoch保存一次模型
print('epoch:',epoch)
print('learning rate:',optimizer.state_dict()['param_groups'][0]['lr'])
#自定义要保存的参数信息
checkpoint = {
"net": model.state_dict(),
'optimizer': optimizer.state_dict(),
"epoch": epoch,
'lr_schedule': lr_schedule.state_dict()
}
if not os.path.isdir("./model_parameter/test"):
os.mkdir("./model_parameter/test")
torch.save(checkpoint, './model_parameter/test/ckpt_best_%s.pth' % (str(epoch)))
模型层该改变时,再次载入之前模型的权重,只需要 model.load_state_dict(torch.load(PATH), strict = False)