出现的问题
- 使用
load_state_dict
函数,如果strict=False
,模型可以加载,但是模型的输出不对 - 如果
strict=True
,但是没有新建字典,模型无法加载
问题解决
- 保存的模型使用分布式方式训练,如果加载后模型不用分布式,则需要修改模型的key.
- 分布式模型的权值名字前面有
module.
,非分布式不包含。
比如分布式模型的权值名称名称为module.blocks.0.norm1.weight
,非分布式模型权值名称名称blocks.0.norm1.weight
- 需要重新建一个字典,用去掉
module.
的名字座位key,值不变
代码示例
chk_filename = os.path.join(args.checkpoint, args.resume if args.resume else args.evaluate)
print('Loading checkpoint', chk_filename)
checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)
checkpoint = checkpoint['model_pos']
new_checkpoint = {} ## 新建一个字典来访模型的权值
# print(checkpoint)
for k,value in checkpoint.items():
key = k.split('module.')[-1]
new_checkpoint[key] = value
# print(k,key)
# model_pos.load_state_dict(checkpoint['model_pos'], strict=True)
model_pos.load_state_dict(new_checkpoint, strict=True)
保存的模型重新加载后,与原始模型预测的结果不一致,从下面几个方面查找原因
- 是否使用了
model.eval()
- 模型结构里面存在dropout?
- 等
参考https://blog.csdn.net/pearl8899/article/details/109661274
参考文章2