解决预训练模型和自己的模型因单卡和多卡训练方式不同,导致不匹配的问题
以要加载的预训练模型是单卡训练,而我的模型需要进行多卡训练为列:
单卡训练的模型load进来之后是这样的:
多卡训练的模型load进来是这样的:
可以很明显的看出来,多卡训练的模型在键值对的键处多了7个字符’module.’ 。
我们加载预训练权重的代码是这样的:
path = ‘预训练模型的地址’
mldel_dict = model.state_dict()
pretrained_dict = torch.load(path)
pretrained_dict = {k:v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict, strict=False)
在更新pretrained_dict时,因为键的参数不匹配, pretrained_dict中的k 在循环的时候, if k in model_dict不满足, 比如当k 为 预训练模型pretrained_dict中的 input_proj.proj.0.bias 时,本地模型中的键是module. input_proj.proj.0.bias , if 语句不满足,所以预训练模型的的pretrained_dict就无法根据本地模型的model_dict进行更新,也就无法进行最后的model更新。
解决的办法其实就是我们在匹配pretrained_dict和model_dict时,关注前七个字符‘module.’, 但是我们不能简单地更改if语句, 比如
pretrained_dict = {k:v for k, v in pretrained_dict.items() if ‘module.’+ k in model_dict}
尽管此时的if语句满足,pretrain_dict也进行了和model_dict 的匹配,但是在运行后两句的时候还是无法更新,因为此时的预训练模型和自己的模型在键的字符上还是不一致,只是在if语句判断的时候做的操作并不影响根本的结构变化。
所以,我们的做法是,可以直接对我们自己多卡训练的模型model_dict,先进行复制,再匹配和更新。
path = ‘预训练模型的地址’
mldel_dict = model.state_dict()
pretrained_dict = torch.load(path)
new_model_dict = {}# 创建一个新的model_dict结构
for old_k, v in model_dict.items():
new_k = k[7:]
new_model_dict[new_k] = v
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in new_model_dict}
new_model_dict.update(pretrained_dict)
model.load_state_dict(new_model_dict, strict=False)