我们在做科研时,常常需要做实验。为此,参考了网上很多的教程,综合各自的写下以下内容。一方面为自己留点笔记,日后好学习。另一方面为各位朋友提供一定的参考。
以卷积神经网络的resnet网络模型为例,简要的说明如何引入模型或者修改模型来做实验。
1、引入模型
model_ft = models.resnet50(pretrained=True)
model_ft = model_ft.to(device)
这里只是引入模型,没有做任何修改。原模型中类别数为1000,因此我们在训练自己的数据集时,需要修改种类数。
2、修改自己数据集的类别
model_ft = models.resnet50(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, n) # 这里的n为类别数
model_ft = model_ft.to(device)
3、如果修改resnet的网络结构,比如加入注意力机制
model_ft = models.resnet50(pretrained=False)
net_dict = model_ft.state_dict()
predict_model = torch.load('resnet50-5c106cde.pth')
# 寻找网络中公共层,并保留预训练参数
state_dict = {k: v for k, v in predict_model.items() if k in net_dict.keys()}
n