pytorch 加载通道不对齐预训练

目录

加载预训练,自定义分类数,删掉分类层参数

yolov5加载预训练,把尺度相同的层拷贝参数,否则就不拷贝,重新训练

pytorch 自己完善的 加载通道不对齐预训练,会把参数切分或者补齐进行拷贝

这个代码加载预训练模型后,再训练无效果:

这个方式加载预训练模型后,新网络可以沿用旧网络的权重,可以正常训练:

完善版,可以正常训练:


加载预训练,自定义分类数,删掉分类层参数

  from easydict import EasyDict as edict

    opts = edict({'width_multiplier':1,
                  "model.classification.n_classes":2,
                  'attn_norm_layer': "layer_norm_2d", "model.normalization.name": "batch_norm", "model.activation.name": "swin"})

    model = MobileViTv3(opts)

    pth_path = r'.\mobilevitv3_1_0_0\checkpoint_ema_best.pt'

    state_dict = torch.load(pth_path, map_location='cpu')

    compatible_state_dict = {}
    for k, v in state_dict.items():
        if 'module.' in k:
            compatible_state_dict[k[7:]] = v
        else:
            if "classifier" in k:
                continue
            compatible_state_dict[k] = v

    model.load_state_dict(compatible_state_dict, strict=False)

yolov5加载预训练,把尺度相同的层拷贝参数,否则就不拷贝,重新训练

def intersect_dicts(da, db, exclude=()):
    # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
    return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}



ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak
model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
csd = intersect_dicts(csd, model.state_dict(), exclude=exclude)  # intersect
model.load_state_dict(csd, strict=False)  # load

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI算法网奇

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值