加载模型时出现Unexpected key(s) in state_dict错误
报错截图如下:
反复排查问题没发现为何如此,查看pytorch中文文档发现保存和加载模型方法都完全正确,模型保存和加载代码对比中文文档截图如下:
其中一个方法是在加载模型时添加参数strict=False,可以只保留键值相同的参数避免出错,用法如下:model.load_state_dict(ckpt[‘state_dict’],False),注意ckpt是我加载权重时自定义的对象名。
但是继续训练发现实际上原有的权重都没有加载进来,因为没有键值相同的参数。此处行不通
继续深究发现多一个module是因为通过多GPU训练保存的模型!如果是在cpu训练的模型则不会带有module,解决很简单,在加载权重前先将模型映射到GPU训练,
model = torch.nn.DataParallel(model)
ckpt = torch.load(args.weights)
model.load_state_dict(ckpt[‘state_dict’])
或者还有解决办法,直接将权重里面key值中含有的module.删掉,改写权重键值对
ckpt = torch.load(model_path,map_location=‘cpu’)
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in ckpt.items():
name = k[7:] # remove module.
new_state_dict[name] = v
model.load_state_dict(new_state_dict)