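How to load partial weights in PyTorch
Often we do not want to train a model from scratch: we want to run tests with a pretrained model, or start training from pretrained weights. But the weights of a pretrained model were trained as a whole. For example, a ResNet outputs 1000 classes by default, while our data has 5 classes. How can we still use the pretrained model?
The answer: load only part of the weights.
There are two ways to do this.
Method 1
A network is a stack of many layers. For the default 1000-class residual network, the convolutional layers at the front need no changes, but the final fully connected layer does not fit and has to be changed to match our own number of classes.
Do not pass the class count at initialization. Load the full pretrained weights first, then replace the fully connected layer directly.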
net = resnet34()
net.load_state_dict(torch.load(model_weight_path, map_location=device))
# change fc layer structure
in_channel = net.fc.in_features
net.fc = nn.Linear(in_channel, 5)
print('net:', net)
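The printed structure shows that the final fully connected layer now has 5 outputs.
I did have one question here: this only changes the structure, so aren't the weights still the old ones? For every layer before fc, yes, and that is the point: they keep the pretrained weights. The new nn.Linear, however, is a freshly created layer with newly initialized weights, so the old 1000-class fc weights are discarded and the new layer still has to be trained.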
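The question above can be checked directly. The following is a minimal sketch, not the original script: TinyNet is a hypothetical stand-in for resnet34 (so the example runs without the weight file), and its state dict plays the role of the result of torch.load. It assumes only that the model exposes an fc attribute, as resnet34 does.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for resnet34: a conv backbone plus an fc head."""

    def __init__(self, num_classes=1000):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv(x)))
        return self.fc(torch.flatten(x, 1))


# Pretend this state dict came from torch.load(model_weight_path).
pretrained = TinyNet(num_classes=1000).state_dict()

net = TinyNet()                  # default 1000-class structure
net.load_state_dict(pretrained)  # every key matches, so strict loading works

# Replace the head: the new layer is created with fresh weights.
in_channel = net.fc.in_features
net.fc = nn.Linear(in_channel, 5)

# The backbone still holds the loaded values; only fc is new.
conv_kept = torch.equal(net.conv.weight, pretrained["conv.weight"])
fc_shape = tuple(net.fc.weight.shape)
out_shape = tuple(net(torch.randn(2, 3, 16, 16)).shape)

print("backbone weights kept:", conv_kept)
print("new fc weight shape:", fc_shape)
print("output shape:", out_shape)
```

With the real resnet34 the same comparison could be made on any backbone key, for example the first convolution, against the loaded state dict.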
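Method 2
Pass your own number of classes when initializing the model. The network structure is then correct from the start, and the final fully connected layer outputs 5. The pretrained weights, however, were trained for 1000 classes, so when loading them, simply leave out the parameters that belong to the fully connected layer.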
net = resnet34(num_classes=5)
pre_weights = torch.load(model_weight_path, map_location=device)
# collect the keys of the fc layer first; deleting while iterating would raise an error
del_key = []
for key, _ in pre_weights.items():
    if "fc" in key:
        del_key.append(key)
        # print the fc weights currently in the new 5-class model
        weight_t = net.state_dict()[key].numpy()
        print(key, ":", weight_t)
for key in del_key:
    del pre_weights[key]
# strict=False tolerates the fc keys that are now missing from pre_weights
missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
print("[missing_keys]:", *missing_keys, sep="\n")
print("[unexpected_keys]:", *unexpected_keys, sep="\n")
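Matching on the substring "fc" works for this ResNet, but it depends on the layer name. A common variant, shown here as a sketch and not as the method used above, keeps only the entries whose key and shape both match the new model. TinyNet is again a hypothetical stand-in for resnet34, and its state dict stands in for the loaded pretrained weights.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for resnet34: a conv backbone plus an fc head."""

    def __init__(self, num_classes=1000):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv(x)))
        return self.fc(torch.flatten(x, 1))


# Pretend this state dict came from torch.load(model_weight_path).
pre_weights = TinyNet(num_classes=1000).state_dict()

net = TinyNet(num_classes=5)
model_state = net.state_dict()

# Keep only entries that exist in the new model with an identical shape.
filtered = {
    key: value
    for key, value in pre_weights.items()
    if key in model_state and value.shape == model_state[key].shape
}
skipped = sorted(set(pre_weights) - set(filtered))

# strict=False is still needed, because the fc keys are absent from filtered.
result = net.load_state_dict(filtered, strict=False)
missing_keys = list(result.missing_keys)
unexpected_keys = list(result.unexpected_keys)

print("skipped:", skipped)
print("[missing_keys]:", *missing_keys, sep="\n")
print("[unexpected_keys]:", *unexpected_keys, sep="\n")
```

In both variants missing_keys should list only the fc parameters. Any other name appearing there would mean a backbone layer silently kept its initial weights, so it is worth reading the printed lists rather than ignoring them.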
The complete code is as follows
import os
import torch
import torch.nn as nn
from model import resnet34
def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    # option1
    net = resnet34()
    net.load_state_dict(torch.load(model_weight_path, map_location=device))
    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 5)
    # option2
    # net = resnet34(num_classes=5)
    # pre_weights = torch.load(model_weight_path, map_location=device)
    # del_key = []
    # for key, _ in pre_weights.items():
    #     if "fc" in key:
    #         del_key.append(key)
    #         weight_t = net.state_dict()[key].numpy()
    #         print(key, ":", weight_t)
    # for key in del_key:
    #     del pre_weights[key]
    # missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
    # print("[missing_keys]:", *missing_keys, sep="\n")
    # print("[unexpected_keys]:", *unexpected_keys, sep="\n")


if __name__ == '__main__':
    main()