跑baseline遇到的问题:
- 第一次使用kaggle。国内手机号无法验证通过,导致无法联网使用gpu服务器
解决方案:http://t.csdnimg.cn/rMopF
baseline运行结果:
粗略看代码的心得:
model = timm.create_model('resnet18', pretrained=True, num_classes=2)
# 使用二分类resnet18作为基础模型
criterion = nn.CrossEntropyLoss().cuda()
# loss选择交叉熵损失
optimizer = torch.optim.Adam(model.parameters(), 0.005)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.85)
# optimizer使用Adam,lr选用0.005;设置scheduler逐级递减,递减步长4,递减参数0.85
for epoch in range(2):
scheduler.step()
print('Epoch: ', epoch)
train(train_loader, model, criterion, optimizer, epoch)
val_acc = validate(val_loader, model, criterion)
if val_acc.avg.item() > best_acc:
best_acc = round(val_acc.avg.item(), 2)
torch.save(model.state_dict(), f'./model_{best_acc}.pt')
# 作为baseline,此处只进行了2epoch的微调训练量[盲猜为了节约跑通时间]
如果您觉得有意思,可否点赞收藏关注一下本蒟蒻!谢谢!