- 多gpu训练存储的参数经常会在load的时候由于多了module而错误,因此可以用下面代码去掉
from collections import OrderedDict
pretrained_dict = torch.load(pretraind)
new_state_dict = OrderedDict()
for k, v in pretrained_dict.items():
if k[0:6] == 'module':
name = k[7:] # remove `module.`
new_state_dict[name] = v
else:
new_state_dict[k[:]] = v
model.load_state_dict(new_state_dict)
- 另外一种是backbone相同,但head不同,这时候我们只需要一部分参数,我们同样可以跟上面相似的,只保留有的key,代码如下:
pretrained_dict = torch.load(pretraind)
model_dict = model.state_dict()
print(model_dict.keys())
print(pretrained_dict.keys())
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
做了一个判断,判断key是否在model的dict中也存在,如果存在就保留