checkpoint 中最关键的部分为 model.state_dict(),以下方法都围绕其展开
1.保存为字典
# Save a checkpoint: bundle everything needed to resume training into one dict.
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
...
}, PATH)
加载
# Resume: read the dict back from disk and restore each piece.
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
start_epoch = checkpoint['epoch']
2.上一方法的加强版:额外记录当前结果是否为历史最优,并把最优模型单独另存一份
# Save function
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Save a checkpoint dict to disk under ../models/<args.name>/.

    Args:
        state: dict to serialize (epoch, state_dict, metrics, ...).
        is_best: if True, additionally copy the file to 'model_best.pth.tar'
            so the best-performing weights are always available separately.
        filename: basename of the checkpoint file.

    NOTE(review): reads the module-level `args.name` — assumed to be set by
    the surrounding training script.
    """
    directory = "../models/%s/" % (args.name)
    # exist_ok avoids the check-then-create race of the original
    # `if not os.path.exists(...)` pattern.
    os.makedirs(directory, exist_ok=True)
    filepath = os.path.join(directory, filename)
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(directory, 'model_best.pth.tar'))
# Example: persist the full training state after every epoch.
save_checkpoint({
    'epoch': epoch + 1,
    'state_dict': model.state_dict(),
    'best_prec1': best_prec1,
}, is_best)

# optionally resume from a checkpoint
if args.resume:
    if not os.path.isfile(args.resume):
        print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        # Restore bookkeeping first, then the weights.
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
3.增强(将模型的保存与加载分别用两个函数实现,同时保存网络和优化器的参数)
保存
def saver(model_state_dict, optimizer_state_dict, model_path, epoch, max_to_save=30, step=None):
    """Save model and optimizer state dicts as '<model_path>_<epoch>'.

    Keeps at most `max_to_save` checkpoint files matching `model_path*`,
    deleting the oldest one before saving a new checkpoint.

    Args:
        model_state_dict: result of model.state_dict().
        optimizer_state_dict: result of optimizer.state_dict().
        model_path: path prefix for checkpoint files.
        epoch: epoch number appended to the filename.
        max_to_save: maximum number of checkpoints to retain.
        step: accepted for compatibility with existing call sites; unused.
    """
    existing = glob.glob(model_path + '*')
    if len(existing) >= max_to_save:
        # Sort by modification time so the OLDEST file is removed.
        # (A plain lexicographic sort mis-orders epochs: '10' < '2'.)
        existing.sort(key=os.path.getmtime)
        os.remove(existing[0])
    state_dict = {
        "model_state_dict": model_state_dict,
        "optimizer_state_dict": optimizer_state_dict,
    }
    save_name = model_path + '_' + str(epoch)
    torch.save(state_dict, save_name)
    # Report the name that was actually written (original printed '-' but saved '_').
    print('models {} save successfully!'.format(save_name))
## Usage
from check import loader, saver
# NOTE(review): the original example also passed step=None, but the shown
# `saver` signature has no `step` parameter, which raises a TypeError.
saver(model.state_dict(), optimizer.state_dict(), model_save_path, epoch + 1, max_to_save=100)
加载
def loader(model_path):
    """Load a checkpoint file and return (model_state_dict, optimizer_state_dict)."""
    checkpoint = torch.load(model_path)
    return checkpoint["model_state_dict"], checkpoint["optimizer_state_dict"]
注:
PyTorch 中另有同名的 checkpoint 机制(from torch.utils.checkpoint import checkpoint),与本文的模型保存无关:它用于节省训练过程中的显存——对模型的某一部分,前向传播时不保留中间激活值,反向传播时按 checkpoint 记录的计算方式重新前向计算一次,得到激活值再用于求梯度(以计算时间换显存)。
如果出现:RuntimeError: params/unet.pth is a zip archive (did you mean to use torch.jit.load()?)
是因为 PyTorch 版本不匹配:该 checkpoint 由 PyTorch >= 1.6 保存(默认使用 zip 序列化格式),却在 1.6 之前的旧版本中加载。可在新版本中加载后,用 torch.save(..., _use_new_zipfile_serialization=False) 重新保存,再交给旧版本读取。