# Minimal checkpoint-resume sequence.
# resume_path is the file path of the checkpoint.pth file.
# Map all tensors to the CPU so the checkpoint loads regardless of the device
# it was saved from (ASCII quotes: the typographic quotes were a SyntaxError).
checkpoint = torch.load(resume_path, map_location=torch.device('cpu'))
# Set the epoch that training resumes from.
args.start_epoch = checkpoint['epoch']
# Get the saved model weights (parameters) from the checkpoint.
new_param = checkpoint['state_dict']
# Load the model parameters.
model.load_state_dict(new_param)
# Restore the optimizer state.
optimizer.load_state_dict(checkpoint['optimizer'])
# The following code is taken from the source code of the DCP model (CVPR 2022).
# args.resume is the name of a checkpoint .pth file inside args.snapshot_path.
if args.resume:
    resume_path = osp.join(args.snapshot_path, args.resume)
    if os.path.isfile(resume_path):
        if main_process():
            logger.info("=> loading checkpoint '{}'".format(resume_path))
        # Load the checkpoint with every tensor mapped to the CPU.
        checkpoint = torch.load(resume_path, map_location=torch.device('cpu'))
        # Resume the epoch counter from the value stored in the checkpoint.
        args.start_epoch = checkpoint['epoch']
        # Saved model weights (parameters).
        new_param = checkpoint['state_dict']
        try:
            # Load the model parameters.
            model.load_state_dict(new_param)
        except RuntimeError:  # 1 GPU loads a multi-GPU model
            # Weights saved through a multi-GPU wrapper carry a 'module.'
            # prefix on every key. Strip it so a single-GPU model can load
            # them. Only keys that really have the prefix are renamed. If no
            # key has it, the failure has some other cause, so re-raise the
            # original error instead of mangling every key (which previously
            # replaced the real error with a misleading missing-keys one).
            prefix = 'module.'
            if not any(key.startswith(prefix) for key in new_param):
                raise
            for key in list(new_param.keys()):
                if key.startswith(prefix):
                    new_param[key[len(prefix):]] = new_param.pop(key)
            model.load_state_dict(new_param)
        # Restore the optimizer state.
        optimizer.load_state_dict(checkpoint['optimizer'])
        if main_process():
            logger.info(
                "=> loaded checkpoint '{}' (epoch {})".format(resume_path, checkpoint['epoch'])
            )
    else:
        # Resume was requested but the file is missing: log it and carry on.
        if main_process():
            logger.info("=> no checkpoint found at '{}'".format(resume_path))