语义分割Deeplabev3plus更改预训练权重(GHOSTNET)

def init_weights(self):
    # 先读取下载的预训练的键,读取模型的键
    checkpoint = torch.load('F:/Code/pytorch-deeplab-dualattention/test/state_dict_73.98.pth')
    state_dict = OrderedDict()  # 对字典对象中的元素排序
    # convert data_parallal to model 改变键的名字    更改名:将下载的预训练的键进行改名,if判断语句有很多个,因为结构有变化
    i = 0
    for key in checkpoint:
        # 前24个
        if i in range(0, 6):
            # a = "backbone."
            # b = a + key
            # state_dict[b] = checkpoint[key]
            state_dict[key] = checkpoint[key]
        if i in range(6, 30):
            # a = "backbone.layer1.0"
            a = "layer1.0"
            b = a + key[10:]
            state_dict[b] = checkpoint[key]
        if i in range(30, 72):
            a = "layer1.1"
            b = a + key[10:]
            state_dict[b] = checkpoint[key]
        if i in range(72, 96):
            a = "layer1.2"
            b = a + key[10:]
            state_dict[b] = checkpoint[key]
        if i in range(96, 142):  
            a = "layer2.0"
            b = a + key[10:]
            state_dict[b] = checkpoint[key]
        if i in range(142, 170):
            a = "layer2.1"
            b = a + key[10:]
            state_dict[b] = checkpoint[key]
        if i in range(170, 212): 
            a = "layer3.0"
            b = a + key[10:]
            state_dict[b] = checkpoint[key]
        if i in range(212, 236):
            a = "layer3.1"
            b = a + key[10:]
            state_dict[b] = checkpoint[key]
        if i in range(236, 260):
            a = "layer3.2"
            b = a + key[10:]
            state_dict[b] = checkpoint[key]
        if i in range(260, 284):
            a = "layer3.3"
            b = a + key[10:]
            state_dict[b] = checkpoint[key]
        if i in range(284, 324):
            a = "layer3.4"
            b = a + key[10:]
            state_dict[b] = checkpoint[key]
        if i in range(324, 352):
            a = "layer3.5"
            b = a + key[10:]
            state_dict[b] = checkpoint[key]
        if i in range(352, 398):
            a = "layer4.0"
            b = a + key[10:]
            state_dict[b] = checkpoint[key]
        if i in range(398, 422):
            a = "layer4.1"
            b = a + key[10:]
            state_dict[b] = checkpoint[key]
        if i in range(422, 450):
            a = "layer4.2"
            b = a + key[10:]
            state_dict[b] = checkpoint[key]
        if i in range(450, 474):
            a = "layer4.3"
            b = a + key[10:]
            state_dict[b] = checkpoint[key]
        if i in range(474, 502):
            a = "layer4.4"
            b = a + key[10:]
            state_dict[b] = checkpoint[key]
        i += 1

    # check loaded parameters and created model parameters  去掉module字符
    model_state_dict_ = self.state_dict()
    model_state_dict = OrderedDict()
    for key in model_state_dict_:
        model_state_dict[key] = model_state_dict_[key]

    # 检查权重格式  将不必要的键去掉
    for key in state_dict:
        if key in model_state_dict:
            if state_dict[key].shape != model_state_dict[key].shape:
                print('Skip loading parameter {}, required shape{}, loaded shape{}.'.format(
                    key, model_state_dict[key].shape, state_dict[key].shape))
                state_dict[key] = model_state_dict[key]
        else:
            state_dict.pop(key)
            print('Drop parameter {}.'.format(key))

    for key in model_state_dict:
        if key not in state_dict:
            print('No param {}.'.format(key))
            state_dict[key] = model_state_dict[key]

    # 将权重的key与model的key统一
    model_key = list(model_state_dict_.keys())
    pretrained_key = list(state_dict.keys())
    pre_state_dict = OrderedDict()
    for k in range(len(model_key)):
        pre_state_dict[model_key[k]] = state_dict[pretrained_key[k]]

    self.load_state_dict(pre_state_dict, strict=True)

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值