加载网络权重,去除全连接层的权重

仅作为记录,大佬请跳过。

感谢老师的示范。

fc_keys = [k for k in state_dict.keys() if "fc" in k]
for k in fc_keys:
    del state_dict[k]

查看设计的网络加载的网络权重的有没有不同的层

def load_from_pretrained(self, ckpt_path):
    print(f"==============> Loading weight {ckpt_path} for fine-tuning......")
    ckpt = torch.load(ckpt_path, map_location='cpu')
    state_dict = ckpt

    fc_keys = [k for k in state_dict.keys() if "fc" in k]
    for k in fc_keys:
        del state_dict[k]

    from pprint import pprint
    missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False)
    print('missing_keys = ')
    pprint(missing_keys)
    print('unexpected_keys = ')
    pprint(unexpected_keys)
    print(f"=> loaded successfully '{ckpt_path}'")
    print('ok')

其中,self指设计的网络

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值