问题再现
之前训练的很好的model(mAP=80)保存之后,在另一个文件里加载,结果效果很差劲(mAP=3);
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
if is_best:
torch.save(state, os.path.split(filename)[0] + '/model_best.pth.tar')
else:
torch.save(state, os.path.split(filename)[0] + filename)
if mAP_ema > mAP:
mAP = mAP_ema
state_dict = ema_m.module.state_dict()
else:
state_dict = model.state_dict()
save_checkpoint({
'epoch': epoch + 1,
'state_dict': state_dict,
'best_mAP': best_mAP,
'optimizer' : optimizer.state_dict(),
}, is_best=is_best, filename=os.path.join(args.output, 'checkpoint.pth.tar'))
用 model.module 替代单独的 model
if mAP_ema > mAP:
mAP = mAP_ema
state_dict = ema_m.module.state_dict()
else:
state_dict = model.module.state_dict()
保存模型,重新在另外一个模型加载,跑一遍validate()
,最后结果也很棒;
所以这个方法是有效的;