https://blog.csdn.net/zzh2910/article/details/103987523
Fine-tuning the CNN:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = torchvision.models.resnet18(pretrained=True)  # load the ResNet-18 architecture with pretrained weights
num_ftrs = net.fc.in_features  # number of input features of the fc layer
net.fc = nn.Linear(num_ftrs, 2)  # replace the head so it outputs 2 classes
net = net.to(device)
# use cross-entropy as the loss and SGD with momentum as the optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# decay the learning rate every 5 epochs: new_lr = initial_lr * gamma ** (epoch // step_size)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
# train the model
net = train_model(net, criterion, optimizer, lr_scheduler, num_epochs=10)
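The post calls a `train_model` helper without showing its definition. A minimal sketch of what such a loop typically looks like is below; the signature (taking a `dataloader` argument) and the synthetic data are assumptions for illustration, not the original author's code:

```python
import torch
import torch.nn as nn
import torch.optim as optim

def train_model(net, criterion, optimizer, lr_scheduler, num_epochs, dataloader, device="cpu"):
    """Hypothetical training loop: one optimizer step per batch, one scheduler step per epoch."""
    net.train()
    for epoch in range(num_epochs):
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)            # forward pass
            loss = criterion(outputs, labels)
            loss.backward()                  # backward pass
            optimizer.step()                 # update parameters
        lr_scheduler.step()                  # decay the learning rate once per epoch
    return net

# Smoke test on synthetic data with a tiny linear model standing in for the CNN
X = torch.randn(8, 4)
y = torch.randint(0, 2, (8,))
loader = [(X, y)]  # a list of (inputs, labels) batches works like a DataLoader here
model = nn.Linear(4, 2)
opt = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
sched = optim.lr_scheduler.StepLR(opt, step_size=5, gamma=0.1)
model = train_model(model, nn.CrossEntropyLoss(), opt, sched, num_epochs=10, dataloader=loader)
```

After 10 epochs with `step_size=5` and `gamma=0.1`, the learning rate has been decayed twice, to `0.001 * 0.1 ** 2 = 1e-5`, which is exactly the `new_lr = initial_lr * gamma ** (epoch // step_size)` schedule described above.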
Using the CNN as a fixed feature extractor:
net = torchvision.models.resnet18(pretrained=True)
# freeze the parameters by setting requires_grad = False so their gradients
# are not computed during backpropagation
for param in net.parameters():
    param.requires_grad = False
# a newly constructed layer has requires_grad=True by default
num_ftrs = net.fc.in_features
net.fc = nn.Linear(num_ftrs, 2)
net = net.to(device)
criterion = nn.CrossEntropyLoss()
# optimize only the parameters of the new final layer
optimizer = optim.SGD(net.fc.parameters(), lr=0.001, momentum=0.9)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
net = train_model(net, criterion, optimizer, lr_scheduler, num_epochs=20)