一、问题描述
多个GPU 训练,保存时没有加module , 导致加载模型时报错。正确写法应该如下:
# save model
if num_gpu == 1:
torch.save(model.state_dict(), os.path.join(opt.outf, 'model.pth'))
else:
torch.save(model.module.state_dict(), os.path.join(opt.outf, 'model.pth'))
就是在多卡训练的时候,存储模型权重的时候,用的是:
torch.save(model.module.state_dict(), os.path.join(opt.outf, 'model.pth'))
二、解决方法
load 模型时,删除多余的module,那个地方缺了“module”关键字,导致在保存模型参数时,参数保存成了这样(模型参数是以key-value的形式保存的),即stat_dict(key),对应的value每个值都多了一个‘module’,直接加载,会报错不匹配。可以将状态字典里的"module"关键字去掉,这样就可以了。