pytorch loading参数 迁移学习

  1. 多gpu训练存储的参数经常会在load的时候由于多了module而错误,因此可以用下面代码去掉
        from collections import OrderedDict
        pretrained_dict = torch.load(pretraind)
        new_state_dict = OrderedDict()
        for k, v in pretrained_dict.items():
            if k[0:6] == 'module':
                name = k[7:]  # remove `module.`
                new_state_dict[name] = v
            else:
                new_state_dict[k[:]] = v
         model.load_state_dict(new_state_dict)
  1. 另外一种是backbone相同,但head不同,这时候我们只需要一部分参数,我们同样可以跟上面相似的,只保留有的key,代码如下:
pretrained_dict = torch.load(pretraind)
model_dict = model.state_dict()
print(model_dict.keys())
print(pretrained_dict.keys())
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

做了一个判断,判断key是否在model的dict中也存在,如果存在就保留

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值