pytorch载入模型参数报错以及解决办法,小心使用strict=False
pytorch载入模型参数报错以及解决办法:
代码:
model = DenseNet121(nnClassCount, nnIsTrained)
pathModel = './models/m-25012018-123527.pth.tar'
checkpoint = torch.load(pathModel)
model.load_state_dict(checkpoint['state_dict'])
问题1:RuntimeError: Error(s) in loading state_dict for DenseNet121: Missing key(s) in state_dict: Unexpected key(s) in state_dict:模型载入参数键不匹配造成的报错
原因
如报错所示:对于model来说:权重参数的索引键前没有module,而将要载入的参数有,所以不匹配,造成报错,可以在debug中看出:
解决办法:
将即将要载入的参数中不匹配的键多余部分,‘module.’删除就可匹配:
params = {k.replace('module.',''):v for k,v in checkpoint['state_dict'].items()} #替换将要载入的参数的键的不匹配部分
问题2:
上述方法可以解决部分网上问题,但是可能是由于代码更新,参数不匹配可能不止发生在多余的’.module’上,后面也可能造成参数键的不匹配,下面举例:
可以看到,model中对应键是”densenet121.features.denseblock1.denselayer1.norm1.weight“
即将载入的权重参数对应的键是“densenet121.features.denseblock1.denselayer1.norm.1.weight”
norm1和norm.1造成的键不匹配,在debug中还可以发现conv1和conv.1的不匹配等造成的报错
解决办法:
将可能会报错的部分进行替换
params = {k.replace('norm.1','norm1'):v for k,v in params.items()}
params = {k.replace('norm.2','norm2'):v for k,v in params.items()}
params = {k.replace('conv.1','conv1'):v for k,v in params.items()}
params = {k.replace('conv.2','conv2'):v for k,v in params.items()}
总结以及注意点
- 在载入参数到模型中时很有可能造成键的不匹配,最重要还是要debug看到参数的具体目录
- 小心使用torch.load_state_dict(params,strict=False),参数strict最好不要设置为False,虽然此举可以使代码不报错,但是实际上代码跳过了检查模型键不匹配的问题,从而造成最后的随机预测结果,如果发现载入参数的模型预测效果很差,就要考虑代码是否设置了此项。