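Problem description:
The PyTorch code raises an error when loading a pretrained model, because the number of classes in the checkpoint differs from that of the current model. Note that strict=False only tolerates missing or unexpected keys; a key present on both sides with a different shape still raises.
Error message:
RuntimeError: Error(s) in loading state_dict for DataParallel: size mismatch for module.fcc.weight:
Code that raises the error: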
checkpoint = torch.load('pretrain.pth', map_location=device)
model = nn.DataParallel(model)
model.load_state_dict(checkpoint['model_state_dict'], strict=False)
model.to(device)
Attempt (delete the mismatched classifier weights before loading):
checkpoint = torch.load('pretrain.pth', map_location=device)
del_keys = ['module.fcc.weight', 'module.fcc.bias', 'module.head.weight', 'module.head.bias']
for k in del_keys:
    del checkpoint[k]
model = nn.DataParallel(model)
model.load_state_dict(checkpoint['model_state_dict'], strict=False)
model.to(device)
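Error: this attempt fails as well, because those keys are not at the top level of checkpoint.
Debugging shows that checkpoint['model_state_dict'] is the dictionary of pretrained weights that the keys must be deleted from: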
In[3]: checkpoint.keys()
Out[3]: dict_keys(['iter', 'model_state_dict'])
In[4]: checkpoint['model_state_dict'].keys()
Out[4]: odict_keys([ 'module.bn1.weight', 'module.bn1.bias', ......, 'module.head.weight', 'module.head.bias', ......])
So the solution is:
checkpoint = torch.load('pretrain.pth', map_location=device)
# Collect the keys first; deleting while iterating over the dict would raise.
del_key = []
for key in checkpoint['model_state_dict']:
    if 'fcc' in key or 'head' in key:
        del_key.append(key)
for key in del_key:
    del checkpoint['model_state_dict'][key]
model = nn.DataParallel(model)
# The deleted layers are now reported as missing keys, which strict=False allows.
model.load_state_dict(checkpoint['model_state_dict'], strict=False)
model.to(device)
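
The solution above selects layers by name ('fcc', 'head'). As a possible alternative, the entries can be filtered by comparing tensor shapes against the current model, so no layer names need to be hard-coded. The following is only a minimal sketch: filter_state_dict and TinyNet are hypothetical names introduced for illustration, and a small CPU model stands in for the real network and for pretrain.pth.

```python
import torch
import torch.nn as nn


def filter_state_dict(pretrained, model):
    """Hypothetical helper: keep entries whose name and shape match the model."""
    target = model.state_dict()
    kept, skipped = {}, []
    for name, tensor in pretrained.items():
        if name in target and target[name].shape == tensor.shape:
            kept[name] = tensor
        else:
            skipped.append(name)
    return kept, skipped


class TinyNet(nn.Module):
    """Stand-in model (an assumption): a backbone plus a classifier named fcc."""

    def __init__(self, num_classes):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.fcc = nn.Linear(4, num_classes)

    def forward(self, x):
        return self.fcc(self.backbone(x))


# Simulate a checkpoint trained with 10 classes and a new model with 3 classes.
pretrained = TinyNet(num_classes=10).state_dict()
model = TinyNet(num_classes=3)

kept, skipped = filter_state_dict(pretrained, model)
# strict=False is still needed because the skipped layers are now missing keys.
result = model.load_state_dict(kept, strict=False)
print('skipped:', skipped)
print('missing:', result.missing_keys)
```

For a model wrapped in nn.DataParallel, the names carry the module. prefix on both sides, so the same comparison should work if the helper is called after wrapping and is given checkpoint['model_state_dict'].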