0、问题描述
当使用torch.nn.DataParallel()在多GPU上训练并保存的模型,直接在CPU上通过load_state_dict()加载模型时会报错,
state_dict=torch.load("DataParallel_pretrained_path")
model.load_state_dict(state_dict)
如下图:
也就是通过多GPU训练保存的模型,某些层名会多出"module."的前缀,导致层名不匹配
1、解决办法
也简单,只要在加载参数之前,将"module.xxxx.xxx"的层名修改为"xxx.xxx"即可
代码如下:
def DataParallel2CPU(model, pth_file):
state_dict = torch.load(pth_file) # 加载参数
new_state_dict = OrderedDict() # 新建字典
for k, v in state_dict.items(): # 遍历参数,并获取名和值
if k[:7] == "module.": # 如果名符合匹配,则截取后面的字符串作为新名字
k = k[7:] # remove `module.`
new_state_dict[k] = v
model.load_state_dict(new_state_dict) # 此时,"module."该前缀被清理掉了
return model