关于 model.load_state_dict 加载部分预训练参数

之前一直使用迁移训练,现记录一下关于model.load_state_dict加载部分预训练参数的学习。

我们在训练模型的时候,会时常用到迁移训练,使用在imagenet上的预训练参数,但是imagenet数据集的类别为1000,与我们自己数据集的类别数目不同时,我们需要修改最后的分类器的类别数,这就会造成加载预训练参数的时候报错。

这里以shufflenetv2为例:

def shufflenetv2(width_mult=1., pretrained = False, n_class = 7):
    model = ShuffleNetV2(width_mult=width_mult, n_class=n_class)
    if pretrained:
        checkpoint = torch.load('../pretrained/shufflenet_v2/shufflenetv2_x1_69.402_88.374.pth.tar')
        model.load_state_dict(checkpoint)
    return model


if __name__ == "__main__":
    model = shufflenetv2(pretrained=True, n_class=7)
    input = torch.randn(1, 3, 244, 244)
    out = model(input)
    print(out.shape)

报错:

size mismatch for classifier.0.weight: copying a param with shape torch.Size([1000, 1024]) from checkpoint, the shape in current model is torch.Size([7, 1024]).
size mismatch for classifier.0.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([7]).

解决办法:

我从网上看了很多使用下面的方法,对比自定义模型与预训练模型的键,剔除不同的部分,然后重新加载。

def shufflenetv2(width_mult=1., pretrained = False, n_class = 7):
    model = ShuffleNetV2(width_mult=width_mult, n_class=n_class)
    if pretrained:
        model_dict = model.state_dict()
        checkpoint = torch.load('../pretrained/shufflenet_v2/shufflenetv2_x1_69.402_88.374.pth.tar')
        # 对比自定义模型与预训练模型的键,剔除不同的部分
        pretrained_dict = {k: v for k, v in checkpoint.items() if k in model_dict}
        # 更新
        model_dict.update(pretrained_dict)
        # 加载
        model.load_state_dict(checkpoint, strict=False)
    return model

但是有一个问题就是,这种方法只能剔除键不同的部分。比如说,我们没有增删shufflenet的任何部分,只是修改了最后分类器的类别数,那么在对比的过程中,classifier是存在的,classifier部分的预训练参数(类别为1000)就会被加载,依然会报错size mismatch,因为我们将最后的类别数修改为7,和1000不匹配。

这个时候有两种解决办法:

1、直接修改shufflenet中的classifier的名字。比如:

class ShuffleNetV2(nn.Module):
    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        super(ShuffleNetV2, self).__init__()

        assert input_size % 32 == 0

        self.stage_repeats = [4, 8, 4]
        # index 0 is invalid and should never be called.
        # only used for indexing convenience.
        if width_mult == 0.5:
            self.stage_out_channels = [-1, 24, 48, 96, 192, 1024]
        elif width_mult == 1.0:
            self.stage_out_channels = [-1, 24, 116, 232, 464, 1024]
        elif width_mult == 1.5:
            self.stage_out_channels = [-1, 24, 176, 352, 704, 1024]
        elif width_mult == 2.0:
            self.stage_out_channels = [-1, 24, 224, 488, 976, 2048]
        else:
            raise ValueError(
                """{} groups is not supported for
                       1x1 Grouped Convolutions""".format(num_groups))

        # building first layer
        input_channel = self.stage_out_channels[1]
        self.conv1 = conv_bn(3, input_channel, 2)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.features = []
        # building inverted residual blocks
        for idxstage in range(len(self.stage_repeats)):
            numrepeat = self.stage_repeats[idxstage]
            output_channel = self.stage_out_channels[idxstage + 2]
            for i in range(numrepeat):
                if i == 0:
                    # inp, oup, stride, benchmodel):
                    self.features.append(InvertedResidual(input_channel, output_channel, 2, 2))
                else:
                    self.features.append(InvertedResidual(input_channel, output_channel, 1, 1))
                input_channel = output_channel

        # make it nn.Sequential
        self.features = nn.Sequential(*self.features)

        # building last several layers
        self.conv_last = conv_1x1_bn(input_channel, self.stage_out_channels[-1])
        self.globalpool = nn.Sequential(nn.AvgPool2d(int(input_size / 32)))
        # building classifier
        self.classifier_my = nn.Sequential(nn.Linear(self.stage_out_channels[-1], n_class))

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.features(x)
        x = self.conv_last(x)
        x = self.globalpool(x)
        x = x.view(-1, self.stage_out_channels[-1])
        x = self.classifier_my(x)
        return x

将self.classifier修改为self.classifier_my。

2、直接将预训练参数的classifier部分的参数剔除,从报错来看,classifier.0.weight 和classifier.0.bias与预训练参数匹配不上,因此将这两部分的参数剔除:

def shufflenetv2(width_mult=1., pretrained = False, n_class = 7):
    model = ShuffleNetV2(width_mult=width_mult, n_class=n_class)
    if pretrained:
        model_dict = model.state_dict()
        checkpoint = torch.load('../pretrained/shufflenet_v2/shufflenetv2_x1_69.402_88.374.pth.tar')
        pretrained_dict = {k: v for k, v in checkpoint.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        checkpoint.pop('classifier.0.weight')
        checkpoint.pop('classifier.0.bias')
        model.load_state_dict(checkpoint, strict=False)
    return model

仅记录一下学习的过程,若有不对,欢迎批评指正。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值