PyTorch加载预训练模型
最近想在自己复现的VGG上加载torchvision官方的预训练模型vgg16_bn-6c64b313.pth,但是发现参数名对不上。找了很久,发现网上的方法大多是直接调用torchvision实现的VGG预训练模型,而我的代码修改了VGG的部分层,无奈之下只能手动加载参数。以前这类东西总是记在本子上,结果经常找不到,这次干脆挂到网上,同时也想交流一下(说不定就有大佬看到了,告诉我有更好的方法呢?)
直接上代码好了:
def load_from_pretrain(self, model_path):
    """Copy pretrained VGG ``features`` weights into this model by position.

    The checkpoint's parameter names (e.g. ``features.0.weight``) do not match
    this model's names (e.g. ``conv11.weight``), so tensors are paired purely
    by order: the i-th loadable checkpoint tensor is assigned to the i-th
    loadable entry of ``self.state_dict()``. This only works when both models
    list their layers in the same order with the same shapes.

    * Loading stops at the first key containing ``classifier``, so the fully
      connected head of the checkpoint is never loaded.
    * ``num_batches_tracked`` entries are skipped on BOTH sides, so
      checkpoints saved with or without them line up the same way.
    * Pairing stops at the shorter of the two lists; entries of this model
      that have no checkpoint counterpart (e.g. extra decoder layers) keep
      their current values.

    Args:
        model_path: Path to a checkpoint file holding a state dict
            (name -> tensor mapping in layer order), e.g. torchvision's
            ``vgg16_bn-6c64b313.pth``.

    Raises:
        RuntimeError: If a checkpoint tensor's shape differs from the model
            entry it is paired with (i.e. the layer order does not match).
    """
    # Current network's parameters and buffers, in registration order.
    model_dict = self.state_dict()
    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only
    # machine; load_state_dict() below copies each value onto the device of
    # the corresponding parameter. NOTE: torch.load unpickles the file, so
    # only load checkpoints from a trusted source.
    load_dict = torch.load(model_path, map_location='cpu')

    # BN layers of the current model carry num_batches_tracked; skip it so the
    # positional pairing is not shifted.
    target_keys = [k for k in model_dict if 'num_batches_tracked' not in k]

    source_items = []
    for name, weight in load_dict.items():
        # Do not load the final fully connected layers (or anything after).
        if 'classifier' in name:
            break
        # A checkpoint may or may not store num_batches_tracked; skip it
        # here too so both sides are filtered the same way.
        if 'num_batches_tracked' in name:
            continue
        source_items.append((name, weight))

    # zip() stops at the shorter list instead of indexing past the end.
    for key, (name, weight) in zip(target_keys, source_items):
        if model_dict[key].shape != weight.shape:
            raise RuntimeError(
                f'shape mismatch when loading {name} {tuple(weight.shape)} '
                f'into {key} {tuple(model_dict[key].shape)}; '
                'the layer order of the two models probably differs'
            )
        model_dict[key] = weight

    self.load_state_dict(model_dict, strict=False)
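基本思路:由于两边的参数名对不上,就按顺序把预训练模型中的参数值依次赋给当前模型中对应位置的参数(前提是两个模型对应层的顺序和形状完全一致);
当前模型中有、而预训练模型中没有的参数(如BN层的num_batches_tracked)直接跳过,预训练模型最后的全连接层(classifier)也不加载。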
补充:
为了自己以后看得懂,顺便放一下vgg16_bn模型的key和对应value的size:
# List every key of the pretrained vgg16_bn checkpoint together with its tensor
# shape. The checkpoint is load_dict (the dict returned by torch.load inside
# load_from_pretrain); model_dict is the current network and would print the
# conv11/bn11/... keys instead.
for k, v in load_dict.items():
    print(k, v.shape)
输出:
features.0.weight torch.Size([64, 3, 3, 3])
features.0.bias torch.Size([64])
features.1.weight torch.Size([64])
features.1.bias torch.Size([64])
features.1.running_mean torch.Size([64])
features.1.running_var torch.Size([64])
features.3.weight torch.Size([64, 64, 3, 3])
features.3.bias torch.Size([64])
features.4.weight torch.Size([64])
features.4.bias torch.Size([64])
features.4.running_mean torch.Size([64])
features.4.running_var torch.Size([64])
features.7.weight torch.Size([128, 64, 3, 3])
features.7.bias torch.Size([128])
features.8.weight torch.Size([128])
features.8.bias torch.Size([128])
features.8.running_mean torch.Size([128])
features.8.running_var torch.Size([128])
features.10.weight torch.Size([128, 128, 3, 3])
features.10.bias torch.Size([128])
features.11.weight torch.Size([128])
features.11.bias torch.Size([128])
features.11.running_mean torch.Size([128])
features.11.running_var torch.Size([128])
features.14.weight torch.Size([256, 128, 3, 3])
features.14.bias torch.Size([256])
features.15.weight torch.Size([256])
features.15.bias torch.Size([256])
features.15.running_mean torch.Size([256])
features.15.running_var torch.Size([256])
features.17.weight torch.Size([256, 256, 3, 3])
features.17.bias torch.Size([256])
features.18.weight torch.Size([256])
features.18.bias torch.Size([256])
features.18.running_mean torch.Size([256])
features.18.running_var torch.Size([256])
features.20.weight torch.Size([256, 256, 3, 3])
features.20.bias torch.Size([256])
features.21.weight torch.Size([256])
features.21.bias torch.Size([256])
features.21.running_mean torch.Size([256])
features.21.running_var torch.Size([256])
features.24.weight torch.Size([512, 256, 3, 3])
features.24.bias torch.Size([512])
features.25.weight torch.Size([512])
features.25.bias torch.Size([512])
features.25.running_mean torch.Size([512])
features.25.running_var torch.Size([512])
features.27.weight torch.Size([512, 512, 3, 3])
features.27.bias torch.Size([512])
features.28.weight torch.Size([512])
features.28.bias torch.Size([512])
features.28.running_mean torch.Size([512])
features.28.running_var torch.Size([512])
features.30.weight torch.Size([512, 512, 3, 3])
features.30.bias torch.Size([512])
features.31.weight torch.Size([512])
features.31.bias torch.Size([512])
features.31.running_mean torch.Size([512])
features.31.running_var torch.Size([512])
features.34.weight torch.Size([512, 512, 3, 3])
features.34.bias torch.Size([512])
features.35.weight torch.Size([512])
features.35.bias torch.Size([512])
features.35.running_mean torch.Size([512])
features.35.running_var torch.Size([512])
features.37.weight torch.Size([512, 512, 3, 3])
features.37.bias torch.Size([512])
features.38.weight torch.Size([512])
features.38.bias torch.Size([512])
features.38.running_mean torch.Size([512])
features.38.running_var torch.Size([512])
features.40.weight torch.Size([512, 512, 3, 3])
features.40.bias torch.Size([512])
features.41.weight torch.Size([512])
features.41.bias torch.Size([512])
features.41.running_mean torch.Size([512])
features.41.running_var torch.Size([512])
classifier.0.weight torch.Size([4096, 25088])
classifier.0.bias torch.Size([4096])
classifier.3.weight torch.Size([4096, 4096])
classifier.3.bias torch.Size([4096])
classifier.6.weight torch.Size([1000, 4096])
classifier.6.bias torch.Size([1000])
还有我的模型的key和对应value的size(其实这个模型是ctrl+v过来的):
# Print each entry of the current model's state_dict with its tensor shape.
for param_name, tensor in model_dict.items():
    print(param_name, tensor.shape)
输出:
conv11.weight torch.Size([64, 3, 3, 3])
conv11.bias torch.Size([64])
bn11.weight torch.Size([64])
bn11.bias torch.Size([64])
bn11.running_mean torch.Size([64])
bn11.running_var torch.Size([64])
bn11.num_batches_tracked torch.Size([])
conv12.weight torch.Size([64, 64, 3, 3])
conv12.bias torch.Size([64])
bn12.weight torch.Size([64])
bn12.bias torch.Size([64])
bn12.running_mean torch.Size([64])
bn12.running_var torch.Size([64])
bn12.num_batches_tracked torch.Size([])
conv21.weight torch.Size([128, 64, 3, 3])
conv21.bias torch.Size([128])
bn21.weight torch.Size([128])
bn21.bias torch.Size([128])
bn21.running_mean torch.Size([128])
bn21.running_var torch.Size([128])
bn21.num_batches_tracked torch.Size([])
conv22.weight torch.Size([128, 128, 3, 3])
conv22.bias torch.Size([128])
bn22.weight torch.Size([128])
bn22.bias torch.Size([128])
bn22.running_mean torch.Size([128])
bn22.running_var torch.Size([128])
bn22.num_batches_tracked torch.Size([])
conv31.weight torch.Size([256, 128, 3, 3])
conv31.bias torch.Size([256])
bn31.weight torch.Size([256])
bn31.bias torch.Size([256])
bn31.running_mean torch.Size([256])
bn31.running_var torch.Size([256])
bn31.num_batches_tracked torch.Size([])
conv32.weight torch.Size([256, 256, 3, 3])
conv32.bias torch.Size([256])
bn32.weight torch.Size([256])
bn32.bias torch.Size([256])
bn32.running_mean torch.Size([256])
bn32.running_var torch.Size([256])
bn32.num_batches_tracked torch.Size([])
conv33.weight torch.Size([256, 256, 3, 3])
conv33.bias torch.Size([256])
bn33.weight torch.Size([256])
bn33.bias torch.Size([256])
bn33.running_mean torch.Size([256])
bn33.running_var torch.Size([256])
bn33.num_batches_tracked torch.Size([])
conv41.weight torch.Size([512, 256, 3, 3])
conv41.bias torch.Size([512])
bn41.weight torch.Size([512])
bn41.bias torch.Size([512])
bn41.running_mean torch.Size([512])
bn41.running_var torch.Size([512])
bn41.num_batches_tracked torch.Size([])
conv42.weight torch.Size([512, 512, 3, 3])
conv42.bias torch.Size([512])
bn42.weight torch.Size([512])
bn42.bias torch.Size([512])
bn42.running_mean torch.Size([512])
bn42.running_var torch.Size([512])
bn42.num_batches_tracked torch.Size([])
conv43.weight torch.Size([512, 512, 3, 3])
conv43.bias torch.Size([512])
bn43.weight torch.Size([512])
bn43.bias torch.Size([512])
bn43.running_mean torch.Size([512])
bn43.running_var torch.Size([512])
bn43.num_batches_tracked torch.Size([])
conv51.weight torch.Size([512, 512, 3, 3])
conv51.bias torch.Size([512])
bn51.weight torch.Size([512])
bn51.bias torch.Size([512])
bn51.running_mean torch.Size([512])
bn51.running_var torch.Size([512])
bn51.num_batches_tracked torch.Size([])
conv52.weight torch.Size([512, 512, 3, 3])
conv52.bias torch.Size([512])
bn52.weight torch.Size([512])
bn52.bias torch.Size([512])
bn52.running_mean torch.Size([512])
bn52.running_var torch.Size([512])
bn52.num_batches_tracked torch.Size([])
conv53.weight torch.Size([512, 512, 3, 3])
conv53.bias torch.Size([512])
bn53.weight torch.Size([512])
bn53.bias torch.Size([512])
bn53.running_mean torch.Size([512])
bn53.running_var torch.Size([512])
bn53.num_batches_tracked torch.Size([])
conv53d.weight torch.Size([512, 512, 3, 3])
conv53d.bias torch.Size([512])
bn53d.weight torch.Size([512])
bn53d.bias torch.Size([512])
bn53d.running_mean torch.Size([512])
bn53d.running_var torch.Size([512])
bn53d.num_batches_tracked torch.Size([])
conv52d.weight torch.Size([512, 512, 3, 3])
conv52d.bias torch.Size([512])
bn52d.weight torch.Size([512])
bn52d.bias torch.Size([512])
bn52d.running_mean torch.Size([512])
bn52d.running_var torch.Size([512])
bn52d.num_batches_tracked torch.Size([])
conv51d.weight torch.Size([512, 512, 3, 3])
conv51d.bias torch.Size([512])
bn51d.weight torch.Size([512])
bn51d.bias torch.Size([512])
bn51d.running_mean torch.Size([512])
bn51d.running_var torch.Size([512])
bn51d.num_batches_tracked torch.Size([])
conv43d.weight torch.Size([512, 512, 3, 3])
conv43d.bias torch.Size([512])
bn43d.weight torch.Size([512])
bn43d.bias torch.Size([512])
bn43d.running_mean torch.Size([512])
bn43d.running_var torch.Size([512])
bn43d.num_batches_tracked torch.Size([])
conv42d.weight torch.Size([512, 512, 3, 3])
conv42d.bias torch.Size([512])
bn42d.weight torch.Size([512])
bn42d.bias torch.Size([512])
bn42d.running_mean torch.Size([512])
bn42d.running_var torch.Size([512])
bn42d.num_batches_tracked torch.Size([])
conv41d.weight torch.Size([256, 512, 3, 3])
conv41d.bias torch.Size([256])
bn41d.weight torch.Size([256])
bn41d.bias torch.Size([256])
bn41d.running_mean torch.Size([256])
bn41d.running_var torch.Size([256])
bn41d.num_batches_tracked torch.Size([])
conv33d.weight torch.Size([256, 256, 3, 3])
conv33d.bias torch.Size([256])
bn33d.weight torch.Size([256])
bn33d.bias torch.Size([256])
bn33d.running_mean torch.Size([256])
bn33d.running_var torch.Size([256])
bn33d.num_batches_tracked torch.Size([])
conv32d.weight torch.Size([256, 256, 3, 3])
conv32d.bias torch.Size([256])
bn32d.weight torch.Size([256])
bn32d.bias torch.Size([256])
bn32d.running_mean torch.Size([256])
bn32d.running_var torch.Size([256])
bn32d.num_batches_tracked torch.Size([])
conv31d.weight torch.Size([128, 256, 3, 3])
conv31d.bias torch.Size([128])
bn31d.weight torch.Size([128])
bn31d.bias torch.Size([128])
bn31d.running_mean torch.Size([128])
bn31d.running_var torch.Size([128])
bn31d.num_batches_tracked torch.Size([])
conv22d.weight torch.Size([128, 128, 3, 3])
conv22d.bias torch.Size([128])
bn22d.weight torch.Size([128])
bn22d.bias torch.Size([128])
bn22d.running_mean torch.Size([128])
bn22d.running_var torch.Size([128])
bn22d.num_batches_tracked torch.Size([])
conv21d.weight torch.Size([64, 128, 3, 3])
conv21d.bias torch.Size([64])
bn21d.weight torch.Size([64])
bn21d.bias torch.Size([64])
bn21d.running_mean torch.Size([64])
bn21d.running_var torch.Size([64])
bn21d.num_batches_tracked torch.Size([])
conv12d.weight torch.Size([64, 64, 3, 3])
conv12d.bias torch.Size([64])
bn12d.weight torch.Size([64])
bn12d.bias torch.Size([64])
bn12d.running_mean torch.Size([64])
bn12d.running_var torch.Size([64])
bn12d.num_batches_tracked torch.Size([])
conv11d.weight torch.Size([20, 64, 3, 3])
conv11d.bias torch.Size([20])