Loading a Pretrained Model in PyTorch (vgg16_bn)


I recently wanted to load torchvision's official pretrained weights, vgg16_bn-6c64b313.pth, into my own re-implementation of VGG, but found that the parameter names didn't match. After searching for a long time, I found that most examples simply instantiate torchvision's own VGG and load the weights directly. Since my code modifies some of VGG's layers, I had no choice but to load the parameters manually. I used to keep tricks like this in a notebook and then could never find them, so this time I'm posting it online, partly to share it (maybe someone will point out a better way?).

Straight to the code:

    def load_from_pretrain(self, model_path):
        # Parameters of the current network
        model_dict = self.state_dict()
        # Load the pretrained VGG weights
        load_dict = torch.load(model_path)
        key = list(model_dict.keys())
        name = list(load_dict.keys())
        weights = list(load_dict.values())

        t = 0
        for i in range(len(weights)):
            # Don't load the final fully connected layers
            if 'classifier' in name[i]:
                break
            # The current model's BN layers carry an extra num_batches_tracked
            # buffer that the loaded weights lack, so skip over it
            if 'num_batches_tracked' in key[i+t]:
                t += 1
            model_dict[key[i+t]] = weights[i]

        self.load_state_dict(model_dict, strict=False)

The basic idea is to copy the loaded weight values, in order, into the corresponding positions of the current model; entries with no counterpart are simply skipped.
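The same order-based alignment can also be written as filter-then-zip: drop the target-only BN counters, drop the source's classifier weights, and pair the rest positionally. This is a sketch of my own (the function name `remap_by_order` is made up, and plain strings stand in for tensors); it only works because the decoder layers come after the encoder layers in my model's state dict order:

```python
def remap_by_order(target_keys, source_items,
                   skip_in_target='num_batches_tracked',
                   drop_from_source='classifier'):
    """Zip source values onto target keys in order, dropping the
    target-only BN counters and the source's classifier weights."""
    tgt = [k for k in target_keys if skip_in_target not in k]
    src = [(k, v) for k, v in source_items if drop_from_source not in k]
    # zip() stops at the shorter list, so target keys beyond the
    # encoder (e.g. decoder layers) are simply left untouched
    return {tk: v for tk, (sk, v) in zip(tgt, src)}

# Toy example covering one conv+BN block; strings stand in for tensors
target = ['conv11.weight', 'conv11.bias', 'bn11.weight', 'bn11.bias',
          'bn11.running_mean', 'bn11.running_var',
          'bn11.num_batches_tracked', 'conv12.weight', 'conv11d.weight']
source = [('features.0.weight', 'W0'), ('features.0.bias', 'b0'),
          ('features.1.weight', 'g1'), ('features.1.bias', 'beta1'),
          ('features.1.running_mean', 'm1'), ('features.1.running_var', 'v1'),
          ('features.3.weight', 'W3'), ('classifier.0.weight', 'Wc')]
aligned = remap_by_order(target, source)
```

The resulting dict can then be merged into `model_dict` before calling `load_state_dict`.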

Addendum:

So that I can still make sense of this later, here are the keys of the vgg16_bn state dict and the sizes of their values (indices such as features.2 and features.6 are missing because ReLU and MaxPool layers have no parameters):

    for k, v in model_dict.items():
        print(k, v.shape)

Output:

features.0.weight torch.Size([64, 3, 3, 3])
features.0.bias torch.Size([64])
features.1.weight torch.Size([64])
features.1.bias torch.Size([64])
features.1.running_mean torch.Size([64])
features.1.running_var torch.Size([64])
features.3.weight torch.Size([64, 64, 3, 3])
features.3.bias torch.Size([64])
features.4.weight torch.Size([64])
features.4.bias torch.Size([64])
features.4.running_mean torch.Size([64])
features.4.running_var torch.Size([64])
features.7.weight torch.Size([128, 64, 3, 3])
features.7.bias torch.Size([128])
features.8.weight torch.Size([128])
features.8.bias torch.Size([128])
features.8.running_mean torch.Size([128])
features.8.running_var torch.Size([128])
features.10.weight torch.Size([128, 128, 3, 3])
features.10.bias torch.Size([128])
features.11.weight torch.Size([128])
features.11.bias torch.Size([128])
features.11.running_mean torch.Size([128])
features.11.running_var torch.Size([128])
features.14.weight torch.Size([256, 128, 3, 3])
features.14.bias torch.Size([256])
features.15.weight torch.Size([256])
features.15.bias torch.Size([256])
features.15.running_mean torch.Size([256])
features.15.running_var torch.Size([256])
features.17.weight torch.Size([256, 256, 3, 3])
features.17.bias torch.Size([256])
features.18.weight torch.Size([256])
features.18.bias torch.Size([256])
features.18.running_mean torch.Size([256])
features.18.running_var torch.Size([256])
features.20.weight torch.Size([256, 256, 3, 3])
features.20.bias torch.Size([256])
features.21.weight torch.Size([256])
features.21.bias torch.Size([256])
features.21.running_mean torch.Size([256])
features.21.running_var torch.Size([256])
features.24.weight torch.Size([512, 256, 3, 3])
features.24.bias torch.Size([512])
features.25.weight torch.Size([512])
features.25.bias torch.Size([512])
features.25.running_mean torch.Size([512])
features.25.running_var torch.Size([512])
features.27.weight torch.Size([512, 512, 3, 3])
features.27.bias torch.Size([512])
features.28.weight torch.Size([512])
features.28.bias torch.Size([512])
features.28.running_mean torch.Size([512])
features.28.running_var torch.Size([512])
features.30.weight torch.Size([512, 512, 3, 3])
features.30.bias torch.Size([512])
features.31.weight torch.Size([512])
features.31.bias torch.Size([512])
features.31.running_mean torch.Size([512])
features.31.running_var torch.Size([512])
features.34.weight torch.Size([512, 512, 3, 3])
features.34.bias torch.Size([512])
features.35.weight torch.Size([512])
features.35.bias torch.Size([512])
features.35.running_mean torch.Size([512])
features.35.running_var torch.Size([512])
features.37.weight torch.Size([512, 512, 3, 3])
features.37.bias torch.Size([512])
features.38.weight torch.Size([512])
features.38.bias torch.Size([512])
features.38.running_mean torch.Size([512])
features.38.running_var torch.Size([512])
features.40.weight torch.Size([512, 512, 3, 3])
features.40.bias torch.Size([512])
features.41.weight torch.Size([512])
features.41.bias torch.Size([512])
features.41.running_mean torch.Size([512])
features.41.running_var torch.Size([512])
classifier.0.weight torch.Size([4096, 25088])
classifier.0.bias torch.Size([4096])
classifier.3.weight torch.Size([4096, 4096])
classifier.3.bias torch.Size([4096])
classifier.6.weight torch.Size([1000, 4096])
classifier.6.bias torch.Size([1000])

And here are the keys and value sizes of my own model (honestly, this model was Ctrl+V'd from somewhere else):

    for k, v in model_dict.items():
        print(k, v.shape)

Output:

conv11.weight torch.Size([64, 3, 3, 3])
conv11.bias torch.Size([64])
bn11.weight torch.Size([64])
bn11.bias torch.Size([64])
bn11.running_mean torch.Size([64])
bn11.running_var torch.Size([64])
bn11.num_batches_tracked torch.Size([])
conv12.weight torch.Size([64, 64, 3, 3])
conv12.bias torch.Size([64])
bn12.weight torch.Size([64])
bn12.bias torch.Size([64])
bn12.running_mean torch.Size([64])
bn12.running_var torch.Size([64])
bn12.num_batches_tracked torch.Size([])
conv21.weight torch.Size([128, 64, 3, 3])
conv21.bias torch.Size([128])
bn21.weight torch.Size([128])
bn21.bias torch.Size([128])
bn21.running_mean torch.Size([128])
bn21.running_var torch.Size([128])
bn21.num_batches_tracked torch.Size([])
conv22.weight torch.Size([128, 128, 3, 3])
conv22.bias torch.Size([128])
bn22.weight torch.Size([128])
bn22.bias torch.Size([128])
bn22.running_mean torch.Size([128])
bn22.running_var torch.Size([128])
bn22.num_batches_tracked torch.Size([])
conv31.weight torch.Size([256, 128, 3, 3])
conv31.bias torch.Size([256])
bn31.weight torch.Size([256])
bn31.bias torch.Size([256])
bn31.running_mean torch.Size([256])
bn31.running_var torch.Size([256])
bn31.num_batches_tracked torch.Size([])
conv32.weight torch.Size([256, 256, 3, 3])
conv32.bias torch.Size([256])
bn32.weight torch.Size([256])
bn32.bias torch.Size([256])
bn32.running_mean torch.Size([256])
bn32.running_var torch.Size([256])
bn32.num_batches_tracked torch.Size([])
conv33.weight torch.Size([256, 256, 3, 3])
conv33.bias torch.Size([256])
bn33.weight torch.Size([256])
bn33.bias torch.Size([256])
bn33.running_mean torch.Size([256])
bn33.running_var torch.Size([256])
bn33.num_batches_tracked torch.Size([])
conv41.weight torch.Size([512, 256, 3, 3])
conv41.bias torch.Size([512])
bn41.weight torch.Size([512])
bn41.bias torch.Size([512])
bn41.running_mean torch.Size([512])
bn41.running_var torch.Size([512])
bn41.num_batches_tracked torch.Size([])
conv42.weight torch.Size([512, 512, 3, 3])
conv42.bias torch.Size([512])
bn42.weight torch.Size([512])
bn42.bias torch.Size([512])
bn42.running_mean torch.Size([512])
bn42.running_var torch.Size([512])
bn42.num_batches_tracked torch.Size([])
conv43.weight torch.Size([512, 512, 3, 3])
conv43.bias torch.Size([512])
bn43.weight torch.Size([512])
bn43.bias torch.Size([512])
bn43.running_mean torch.Size([512])
bn43.running_var torch.Size([512])
bn43.num_batches_tracked torch.Size([])
conv51.weight torch.Size([512, 512, 3, 3])
conv51.bias torch.Size([512])
bn51.weight torch.Size([512])
bn51.bias torch.Size([512])
bn51.running_mean torch.Size([512])
bn51.running_var torch.Size([512])
bn51.num_batches_tracked torch.Size([])
conv52.weight torch.Size([512, 512, 3, 3])
conv52.bias torch.Size([512])
bn52.weight torch.Size([512])
bn52.bias torch.Size([512])
bn52.running_mean torch.Size([512])
bn52.running_var torch.Size([512])
bn52.num_batches_tracked torch.Size([])
conv53.weight torch.Size([512, 512, 3, 3])
conv53.bias torch.Size([512])
bn53.weight torch.Size([512])
bn53.bias torch.Size([512])
bn53.running_mean torch.Size([512])
bn53.running_var torch.Size([512])
bn53.num_batches_tracked torch.Size([])
conv53d.weight torch.Size([512, 512, 3, 3])
conv53d.bias torch.Size([512])
bn53d.weight torch.Size([512])
bn53d.bias torch.Size([512])
bn53d.running_mean torch.Size([512])
bn53d.running_var torch.Size([512])
bn53d.num_batches_tracked torch.Size([])
conv52d.weight torch.Size([512, 512, 3, 3])
conv52d.bias torch.Size([512])
bn52d.weight torch.Size([512])
bn52d.bias torch.Size([512])
bn52d.running_mean torch.Size([512])
bn52d.running_var torch.Size([512])
bn52d.num_batches_tracked torch.Size([])
conv51d.weight torch.Size([512, 512, 3, 3])
conv51d.bias torch.Size([512])
bn51d.weight torch.Size([512])
bn51d.bias torch.Size([512])
bn51d.running_mean torch.Size([512])
bn51d.running_var torch.Size([512])
bn51d.num_batches_tracked torch.Size([])
conv43d.weight torch.Size([512, 512, 3, 3])
conv43d.bias torch.Size([512])
bn43d.weight torch.Size([512])
bn43d.bias torch.Size([512])
bn43d.running_mean torch.Size([512])
bn43d.running_var torch.Size([512])
bn43d.num_batches_tracked torch.Size([])
conv42d.weight torch.Size([512, 512, 3, 3])
conv42d.bias torch.Size([512])
bn42d.weight torch.Size([512])
bn42d.bias torch.Size([512])
bn42d.running_mean torch.Size([512])
bn42d.running_var torch.Size([512])
bn42d.num_batches_tracked torch.Size([])
conv41d.weight torch.Size([256, 512, 3, 3])
conv41d.bias torch.Size([256])
bn41d.weight torch.Size([256])
bn41d.bias torch.Size([256])
bn41d.running_mean torch.Size([256])
bn41d.running_var torch.Size([256])
bn41d.num_batches_tracked torch.Size([])
conv33d.weight torch.Size([256, 256, 3, 3])
conv33d.bias torch.Size([256])
bn33d.weight torch.Size([256])
bn33d.bias torch.Size([256])
bn33d.running_mean torch.Size([256])
bn33d.running_var torch.Size([256])
bn33d.num_batches_tracked torch.Size([])
conv32d.weight torch.Size([256, 256, 3, 3])
conv32d.bias torch.Size([256])
bn32d.weight torch.Size([256])
bn32d.bias torch.Size([256])
bn32d.running_mean torch.Size([256])
bn32d.running_var torch.Size([256])
bn32d.num_batches_tracked torch.Size([])
conv31d.weight torch.Size([128, 256, 3, 3])
conv31d.bias torch.Size([128])
bn31d.weight torch.Size([128])
bn31d.bias torch.Size([128])
bn31d.running_mean torch.Size([128])
bn31d.running_var torch.Size([128])
bn31d.num_batches_tracked torch.Size([])
conv22d.weight torch.Size([128, 128, 3, 3])
conv22d.bias torch.Size([128])
bn22d.weight torch.Size([128])
bn22d.bias torch.Size([128])
bn22d.running_mean torch.Size([128])
bn22d.running_var torch.Size([128])
bn22d.num_batches_tracked torch.Size([])
conv21d.weight torch.Size([64, 128, 3, 3])
conv21d.bias torch.Size([64])
bn21d.weight torch.Size([64])
bn21d.bias torch.Size([64])
bn21d.running_mean torch.Size([64])
bn21d.running_var torch.Size([64])
bn21d.num_batches_tracked torch.Size([])
conv12d.weight torch.Size([64, 64, 3, 3])
conv12d.bias torch.Size([64])
bn12d.weight torch.Size([64])
bn12d.bias torch.Size([64])
bn12d.running_mean torch.Size([64])
bn12d.running_var torch.Size([64])
bn12d.num_batches_tracked torch.Size([])
conv11d.weight torch.Size([20, 64, 3, 3])
conv11d.bias torch.Size([20])
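One caveat of the purely positional copy: if the layer order ever drifts between the two models, wrong weights get assigned silently. A defensive variant (my own sketch, operating on shape tuples rather than tensors for illustration) checks that each aligned pair agrees in shape before assigning:

```python
def copy_checked(model_shapes, load_shapes):
    """Order-based alignment that raises if aligned entries disagree
    in shape. Both arguments map parameter name -> shape tuple."""
    keys = list(model_shapes.keys())
    t = 0
    for i, (name, shape) in enumerate(load_shapes.items()):
        if 'classifier' in name:
            break
        # Skip the target-only num_batches_tracked buffers, as before
        while 'num_batches_tracked' in keys[i + t]:
            t += 1
        if model_shapes[keys[i + t]] != shape:
            raise ValueError(
                f'shape mismatch: {name} {shape} vs '
                f'{keys[i + t]} {model_shapes[keys[i + t]]}')
        # In the real method, model_dict[key[i+t]] = weights[i]
        # would go here once the check passes

# First block from the listings above: shapes agree, so no exception
model = {'conv11.weight': (64, 3, 3, 3), 'conv11.bias': (64,),
         'bn11.weight': (64,), 'bn11.bias': (64,),
         'bn11.running_mean': (64,), 'bn11.running_var': (64,),
         'bn11.num_batches_tracked': ()}
load = {'features.0.weight': (64, 3, 3, 3), 'features.0.bias': (64,),
        'features.1.weight': (64,), 'features.1.bias': (64,),
        'features.1.running_mean': (64,), 'features.1.running_var': (64,)}
copy_checked(model, load)
```

In the real method the shapes would come from `v.shape` of the actual tensors; the check costs nothing and turns a silent corruption into an immediate error.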
