- 第一种
model_weight_path = "./resnet34-pre.pth"
assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
net = resnet34(num_classes=5)#模型可以带fc层,并设输出channel为5
pre_weights = torch.load(model_weight_path, map_location=device)
# print(pre_weights)
del_key = []
for key, _ in pre_weights.items():
if "fc" in key:
del_key.append(key) # 这时del_key为['fc.weight', 'fc.bias']
for key in del_key:
del pre_weights[key]#pth权重中,带有fc的删掉
net.load_state_dict(pre_weights,strict=False)
为什么strict=False,👇,不设置的话,会报没有fc层参数的错
- 第二种,迁移学习官方的方法
net = resnet34()
net.load_state_dict(torch.load(model_weight_path, map_location=device))
# change fc layer structure
in_channel = net.fc.in_features
net.fc = nn.Linear(in_channel, 5)