PyTorch修改网络结构后加载预训练模型,博客给出的代码实测有效。以下为主要代码。详情访问原博客。
net = AlexNet(num_classes=5, init_weights=True)
net.to(device)
# net.load_state_dict(torch.load('alexnet_model.pth'))
net_dict = net.state_dict()
predict_model = torch.load('alexnet_model.pth')
print('start')
state_dict = {k: v for k, v in predict_model.items() if k in net_dict.keys()}
# 寻找网络中公共层,并保留预训练参数
print(state_dict.keys())
net_dict.update(state_dict) # 将预训练参数更新到新的网络层
net.load_state_dict(net_dict) # 加载预训练参数