checkpoint = load_checkpoint(args.resume)
model_dict = model.state_dict()
checkpoint_load = {k: v for k, v in (checkpoint['state_dict']).items()
if k in model_dict}
model_dict.update(checkpoint_load)
model.load_state_dict(model_dict)
start_epoch = checkpoint['epoch']
best_top1 = checkpoint['best_top1']
print("=> Start epoch {} best top1 {:.1%}".format(start_epoch, best_top1))
1.首先先读取arg.sume(已存储的权重)到checkpoint,相当于字典
2.再读取模型中的参数权重到model_dict
3.将checkpoint中key值对应model_dict的数据加载到checkpoint_load中
4.将已经训练好的模型参数更新并加载到已有模型参数中(单卡)
5.再读取checkpoint中的其他参数,以此类推
model.module.load_state_dict(checkpoint['state_dict'])
加载模型参数(多卡)
torch.save(model.state_dict(), model_out_path)
存储模型参数(单卡)
torch.save(state, fpath)
save_checkpoint({
'state_dict': model.module.state_dict(),
'epoch': epoch + 1,
'best_top1': best_top1
})
存储模型参数(多卡)以及其他信息,