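Project scenario: weights saved from a torch.compile model cannot be loaded
A checkpoint saved from a model compiled with torch.compile cannot be loaded directly with load_state_dict into the uncompiled model.
Problem description
Running model.load_state_dict(torch.load(model_path)) raises an error.
For example, the model expects the key gate.block1.1.weight, but the checkpoint has no such key; it contains '_orig_mod.gate.block1.1.weight' instead.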
RuntimeError: Error(s) in loading state_dict for moe:
Missing key(s) in state_dict: "gate.block1.1.weight", "gate.block1.2.weight", "gate.block1.2.bias", "gate.block1.2.running_mean", "gate.block1.2.running_var", "gate.block2.0.weight", "gate.block2.1.weight", "gate.block2.1.bias", "gate.block2.1.running_mean", "gate.block2.1.running_var", "gate.block3.1.weight", "gate.block3.2.weight", "gate.block3.2.bias", "gate.block3.2.running_mean", "gate.block3.2.running_var", "gate.fc1.0.weight", "gate.fc1.0.bias", "gate.fc2.0.weight", "gate.fc2.0.bias", "experts.label_branch.0.block1.1.weight", "experts.label_branch.0.block1.2.weight", "experts.label_branch.0.block1.2.bias", "experts.label_branch.0.block1.2.running_mean", "experts.label_branch.0.block1.2.running_var", "experts.label_branch.0.block2.0.weight", "experts.label_branch.0.block2.1.weight", "experts.label_branch.0.block2.1.bias", "experts.label_branch.0.block2.1.running_mean", "experts.label_branch.0.block2.1.running_var", "experts.label_branch.0.block3.1.weight", "experts.label_branch.0.block3.2.weight", "experts.label_branch.0.block3.2.bias", "experts.label_branch.0.block3.2.running_mean", "experts.label_branch.0.block3.2.running_var", "experts.label_branch.0.fc1.0.weight", "experts.label_branch.0.fc1.0.bias", "experts.label_branch.0.fc2.0.weight", "experts.label_branch.0.fc2.0.bias", "experts.label_branch.1.temp_conv.weight", "experts.label_branch.1.temp_conv.bias", "experts.label_branch.1.spat_conv.weight", "experts.label_branch.1.BN.weight", "experts.label_branch.1.BN.bias", "experts.label_branch.1.BN.running_mean", "experts.label_branch.1.BN.running_var", "experts.label_branch.1.fc1.weight", "experts.label_branch.1.fc1.bias", "experts.label_branch.1.fc2.weight", "experts.label_branch.1.fc2.bias".
Unexpected key(s) in state_dict: "_orig_mod.gate.block1.1.weight", "_orig_mod.gate.block1.2.weight", "_orig_mod.gate.block1.
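Cause analysis:
The state dict of a model wrapped by torch.compile has '_orig_mod.' prepended to every key, for example '_orig_mod.gate.block1.1.weight' instead of 'gate.block1.1.weight'. A checkpoint saved from the compiled model therefore no longer matches the key names of the uncompiled model.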
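To confirm that the prefix is the only difference before changing anything, the two key sets can be compared directly. The following is a minimal sketch: `explain_key_mismatch` is a hypothetical helper, not part of PyTorch, and it works on plain lists of key names so it runs without a model.

```python
def explain_key_mismatch(model_keys, checkpoint_keys, prefix="_orig_mod."):
    """Report how the key sets differ, and whether stripping `prefix` fixes it."""
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    # Strip only a leading prefix; keys without it are kept unchanged.
    stripped = {
        k[len(prefix):] if k.startswith(prefix) else k for k in checkpoint_keys
    }
    return {
        "missing": sorted(model_keys - checkpoint_keys),
        "unexpected": sorted(checkpoint_keys - model_keys),
        "match_after_strip": stripped == model_keys,
    }


# Key names taken from the error message above.
model_keys = ["gate.block1.1.weight", "gate.block1.2.weight"]
checkpoint_keys = [
    "_orig_mod.gate.block1.1.weight",
    "_orig_mod.gate.block1.2.weight",
]
report = explain_key_mismatch(model_keys, checkpoint_keys)
print(report)
```

With a real model, the inputs would presumably be model.state_dict().keys() and torch.load(model_path).keys(). If match_after_strip is False, something other than the prefix differs and removing it will not be enough.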
Solution:
Strip the prefix: read the original keys and replace each one with the same key minus '_orig_mod.'.
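import torch


def load_weights(path):
    # torch.compile (PyTorch 2.0+) makes the model run faster; at the time of
    # writing, compile was only supported on Linux.
    # A state dict saved from a compiled model has '_orig_mod.' added in front
    # of every layer name. This function removes that prefix.
    # Takes the checkpoint path (.pth) and returns a state dict that can be
    # passed straight to model.load_state_dict(weight).
    weight = torch.load(path)
    new_weight = weight.copy()
    keys_list = list(weight.keys())
    for key in keys_list:
        # Test for the same string that is removed below, so a key is never
        # deleted without being re-added under its new name.
        if '_orig_mod.' in key:
            del_key = key.replace('_orig_mod.', '')
            new_weight[del_key] = weight[key]
            del new_weight[key]
    return new_weight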
# Original call, which fails:
# model.load_state_dict(torch.load(model_path))
# Use this instead:
weight = load_weights(model_path)
model.load_state_dict(weight)
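The fix above repairs the checkpoint at load time. Another option is to avoid the prefix when saving, by taking the state dict from the original module instead of the compiled wrapper. The sketch below assumes the wrapper keeps the original module in an attribute named `_orig_mod`, which is what the key prefix suggests. `unwrap_compiled` and the two stand-in classes are hypothetical and exist only so the example runs without PyTorch.

```python
def unwrap_compiled(model):
    # Assumption: a compiled model stores the original module as `_orig_mod`.
    # A plain model has no such attribute and is returned unchanged.
    return getattr(model, "_orig_mod", model)


# Stand-in classes (hypothetical) that mimic the key naming seen in the error.
class PlainModel:
    def state_dict(self):
        return {"gate.block1.1.weight": 1.0}


class CompiledWrapper:
    def __init__(self, orig):
        self._orig_mod = orig

    def state_dict(self):
        inner = self._orig_mod.state_dict()
        return {"_orig_mod." + k: v for k, v in inner.items()}


compiled = CompiledWrapper(PlainModel())
prefixed_keys = list(compiled.state_dict().keys())
clean_keys = list(unwrap_compiled(compiled).state_dict().keys())
print(prefixed_keys, clean_keys)
```

With a real model, the saving step would then look something like torch.save(unwrap_compiled(model).state_dict(), model_path), so the checkpoint can be loaded whether or not the model is compiled. Checkpoints that were already saved with the prefix still need load_weights.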