【pytorch载入模型参数报错以及解决办法,小心使用strict=False】

pytorch载入模型参数报错以及解决办法:

代码:

model = DenseNet121(nnClassCount, nnIsTrained)
pathModel = './models/m-25012018-123527.pth.tar'
checkpoint = torch.load(pathModel)
model.load_state_dict(checkpoint['state_dict'])

问题1:RuntimeError: Error(s) in loading state_dict for DenseNet121: Missing key(s) in state_dict: Unexpected key(s) in state_dict:模型载入参数键不匹配造成的报错报错内容

原因

如报错所示:对于model来说:权重参数的索引键前没有module,而将要载入的参数有,所以不匹配,造成报错,可以在debug中看出:
model_debug
checkpoin_debug

解决办法:

将即将要载入的参数中不匹配的键多余部分,‘module.’删除就可匹配:

params = {k.replace('module.',''):v for k,v in checkpoint['state_dict'].items()} 	#替换将要载入的参数的键的不匹配部分

问题2:

上述方法可以解决部分网上问题,但是可能是由于代码更新,参数不匹配可能不止发生在多余的’.module’上,后面也可能造成参数键的不匹配,下面举例:
报错
可以看到,model中对应键是”densenet121.features.denseblock1.denselayer1.norm1.weight“
即将载入的权重参数对应的键是“densenet121.features.denseblock1.denselayer1.norm.1.weight”
norm1和norm.1造成的键不匹配,在debug中还可以发现conv1和conv.1的不匹配等造成的报错

解决办法:

将可能会报错的部分进行替换

params = {k.replace('norm.1','norm1'):v for k,v in params.items()}
params = {k.replace('norm.2','norm2'):v for k,v in params.items()}
params = {k.replace('conv.1','conv1'):v for k,v in params.items()}
params = {k.replace('conv.2','conv2'):v for k,v in params.items()}

总结以及注意点

  • 在载入参数到模型中时很有可能造成键的不匹配,最重要还是要debug看到参数的具体目录
  • 小心使用torch.load_state_dict(params,strict=False),参数strict最好不要设置为False,虽然此举可以使代码不报错,但是实际上代码跳过了检查模型键不匹配的问题,从而造成最后的随机预测结果,如果发现载入参数的模型预测效果很差,就要考虑代码是否设置了此项。
  • 12
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值