pytorch 载入模型参数finetune训练

pytorch 在finetune重训练时,采用torch.load()方式载入模型
经常会报错。
这里给出一种load方式,若模型中存在相同的tensor(名字和大小一致)则载入,否则对模型中tensor只做初始化处理。
代码如下:

def check_keys(model, pretrained_state_dict):
    ckpt_keys = set(pretrained_state_dict.keys())
    model_keys= set(model.state_dict().keys())
    used_pretrained_keys = model_keys & ckpt_keys
    unused_pretrained_keys = ckpt_keys-model_keys
    missing_keys = model_keys  - ckpt_keys
    assert len(uesd_pretrained_keys) > 0, 'load None from pretrained checkpoint'
    return True
def remove_prefix(state_dict, prefix):
   f = lambda x:x.split(prefix,1)[-1] if x.startswith(prefix) else x
   return {f(key):value for key,value in state_dict.items()}
    
def load_model(model,pretrained_path,load_to_cpu):
    if load_to_cpu:
        pretrained_dict  = torch.load(pretrained_path, map_location=lambda storage, loc:storage)
   else:
       device = torch.cuda.current_device()
       pretrained_dict =torch.load(pretrained_path,map_location=lambda storage, loc:storage.cuda(device))
   if "state_dict" in pretrained_dict.keys():
        pretrained_dict=remove_prefix(pretrained_dict['state_dict'],'module.')
   else:
       pretrained_dict = remove_prefix(pretrained_dict, 'module.')
  check_keys(model, pretrained_dict)
  model.load_state_dict(pretrained_dict,strict=False)
  return model

代码来源

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值