1、假设在某个epoch,我们要保存模型参数,优化器参数以及epoch
① 先建立一个字典,保存三个参数:
state = {‘net':model.state_dict(),
'optimizer':optimizer.state_dict(),
'epoch':epoch}
2.调用torch.save():
torch.save(state, path)
其中path表示保存文件的绝对路径+文件名。
当你想恢复某一阶段的训练(或者进行测试)时,就可以读取之前保存的网络模型参数等。
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1
3、当我们修改了一部分网络,比如加了一些,删除一些,需要过滤某些参数,加载方式:
def load_checkpoint(model, checkpoint, optimizer, loadOptimizer):
if checkpoint != 'No':
print("loading checkpoint...")
model_dict = model.state_dict()
modelCheckpoint = torch.load(checkpoint)
pretrained_dict = modelCheckpoint['state_dict']
# 过滤操作
new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
model_dict.update(new_dict)
# 打印出来,更新了多少的参数
print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict)))
model.load_state_dict(model_dict)
print("loaded finished!")
# 如果不需要更新优化器那么设置为false
if loadOptimizer == True:
optimizer.load_state_dict(modelCheckpoint['optimizer'])
print('loaded! optimizer')
else:
print('not loaded optimizer')
else:
print('No checkpoint is included')
return model, optimizer