前言
迁移学习的方法被广泛应用于卷积神经网络,基于大数据集训练而得到的权重文件对数据具有强的特征提取能力,在此基础上针对特有数据集进行模型的二次训练(微调),能大大降低训练时长以及犯错成本。
当改变卷积神经网络模型结构后,原有的预训练权重将无法成功加载到已经改变了的模型中,以下提供了针对模型的修改,实现指定部分权重的加载。
代码如下:
# load pretrain weights
model_weight_path = "./resnet34_pre.pth"
assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
# Official_Option
# net = resnet34()
# net.load_state_dict(torch.load(model_weight_path, map_location=device))
# change fc layer structure
# in_channel = net.fc.in_features
# net.fc = nn.Linear(in_channel, 5)
# The Other option
net = resnet34(num_classes=100)
net_State = net.state_dict()
pre_weights = torch.load(model_weight_path, map_location=device)
del_key = []
for key, _ in pre_weights.items():
if "fc" in key or "layer4" in key:
del_key.append(key)
for key in del_key:
del pre_weights[key]
missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
print("[missing_keys]:", *missing_keys, sep="\n")
print("[unexpected_keys]:", *unexpected_keys, sep="\n")