[Weights saved from a torch.compile model cannot be loaded on Windows]

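Scenario: weights saved from a torch.compile model cannot be loaded

Weights saved from a model compiled with torch.compile cannot be loaded directly with load_state_dict.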


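Problem description

Running model.load_state_dict(torch.load(model_path)) raises an error.
For example, the model expects the key gate.block1.1.weight, which is missing from the checkpoint; the checkpoint contains '_orig_mod.gate.block1.1.weight' instead.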

RuntimeError: Error(s) in loading state_dict for moe:
	Missing key(s) in state_dict: "gate.block1.1.weight", "gate.block1.2.weight", "gate.block1.2.bias", "gate.block1.2.running_mean", "gate.block1.2.running_var", "gate.block2.0.weight", "gate.block2.1.weight", "gate.block2.1.bias", "gate.block2.1.running_mean", "gate.block2.1.running_var", "gate.block3.1.weight", "gate.block3.2.weight", "gate.block3.2.bias", "gate.block3.2.running_mean", "gate.block3.2.running_var", "gate.fc1.0.weight", "gate.fc1.0.bias", "gate.fc2.0.weight", "gate.fc2.0.bias", "experts.label_branch.0.block1.1.weight", "experts.label_branch.0.block1.2.weight", "experts.label_branch.0.block1.2.bias", "experts.label_branch.0.block1.2.running_mean", "experts.label_branch.0.block1.2.running_var", "experts.label_branch.0.block2.0.weight", "experts.label_branch.0.block2.1.weight", "experts.label_branch.0.block2.1.bias", "experts.label_branch.0.block2.1.running_mean", "experts.label_branch.0.block2.1.running_var", "experts.label_branch.0.block3.1.weight", "experts.label_branch.0.block3.2.weight", "experts.label_branch.0.block3.2.bias", "experts.label_branch.0.block3.2.running_mean", "experts.label_branch.0.block3.2.running_var", "experts.label_branch.0.fc1.0.weight", "experts.label_branch.0.fc1.0.bias", "experts.label_branch.0.fc2.0.weight", "experts.label_branch.0.fc2.0.bias", "experts.label_branch.1.temp_conv.weight", "experts.label_branch.1.temp_conv.bias", "experts.label_branch.1.spat_conv.weight", "experts.label_branch.1.BN.weight", "experts.label_branch.1.BN.bias", "experts.label_branch.1.BN.running_mean", "experts.label_branch.1.BN.running_var", "experts.label_branch.1.fc1.weight", "experts.label_branch.1.fc1.bias", "experts.label_branch.1.fc2.weight", "experts.label_branch.1.fc2.bias". 
	Unexpected key(s) in state_dict: "_orig_mod.gate.block1.1.weight", "_orig_mod.gate.block1.2.weight", "_orig_mod.gate.block1.

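Cause analysis:

After torch.compile, every key in the model's weight dictionary gets the prefix '_orig_mod.' (for example '_orig_mod.gate.block1.1.weight' instead of 'gate.block1.1.weight'), so the saved keys no longer match the keys of the uncompiled model.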
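The mismatch can be reproduced without PyTorch by comparing the two sets of key names. The following is only a minimal sketch: `diff_keys` is a hypothetical helper that imitates how a strict load reports keys, and the key names are copied from the error message above.

```python
# Keys as stored in the checkpoint saved from the compiled model
# (names copied from the error message above).
checkpoint_keys = [
    "_orig_mod.gate.block1.1.weight",
    "_orig_mod.gate.block1.2.weight",
]
# Keys the uncompiled model expects.
model_keys = [
    "gate.block1.1.weight",
    "gate.block1.2.weight",
]


def diff_keys(expected, found):
    """Hypothetical helper: list keys that are missing and keys that are unexpected."""
    missing = [k for k in expected if k not in found]
    unexpected = [k for k in found if k not in expected]
    return missing, unexpected


# Every expected key is missing and every stored key is unexpected.
missing, unexpected = diff_keys(model_keys, checkpoint_keys)
print("Missing:", missing)
print("Unexpected:", unexpected)

# After removing the prefix, both lists are empty.
stripped_keys = [k.replace("_orig_mod.", "") for k in checkpoint_keys]
print(diff_keys(model_keys, stripped_keys))
```
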


Solution:

Just remove the prefix: read the original keys and replace them with keys that do not have it.

import torch

def load_weights(path):
    # torch.compile is available from PyTorch 2.0 and makes the model run faster,
    # but at the time of writing it only supported Linux.
    # When a compiled model's weights are saved, each layer name gets '_orig_mod.' prepended.
    # This function removes that prefix.
    # Pass in the checkpoint path (.pth); the returned weights can be loaded
    # directly with model.load_state_dict(weight).
    weight = torch.load(path)
    new_weight = weight.copy()
    for key in list(weight.keys()):
        if '_orig_mod.' in key:
            new_key = key.replace('_orig_mod.', '')
            new_weight[new_key] = weight[key]
            del new_weight[key]
    return new_weight

# Original call, which fails:
# model.load_state_dict(torch.load(model_path))
# Use this instead:
weight = load_weights(model_path)
model.load_state_dict(weight)
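The prefix can also be avoided when saving. The key prefix reflects the fact that the object returned by torch.compile keeps the original module under the attribute `_orig_mod`, so saving the weights of that inner module should produce keys without the prefix. The sketch below is an assumption-laden illustration, not the method used above: `unwrap_compiled` is a hypothetical helper, and `PlainModel` and `CompiledWrapper` are stand-ins that only imitate the key naming so the example runs without PyTorch.

```python
def unwrap_compiled(model):
    """Hypothetical helper: return the wrapped original module if there is one,
    otherwise return the model unchanged."""
    return getattr(model, "_orig_mod", model)


# Stand-in classes (not PyTorch) that only imitate the key naming.
class PlainModel:
    def state_dict(self):
        return {"gate.block1.1.weight": 0.0}


class CompiledWrapper:
    def __init__(self, orig):
        self._orig_mod = orig

    def state_dict(self):
        # Imitates the '_orig_mod.' prefix seen in the error message.
        return {"_orig_mod." + k: v for k, v in self._orig_mod.state_dict().items()}


plain = PlainModel()
wrapped = CompiledWrapper(plain)

print(list(wrapped.state_dict()))                   # keys with the prefix
print(list(unwrap_compiled(wrapped).state_dict()))  # keys without the prefix

# With a real model, the save call would look like:
# torch.save(unwrap_compiled(model).state_dict(), path)
```

Because the helper falls back to the model itself, the same save call works for both compiled and uncompiled models.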