import torch
from torch import nn
from torch.cuda.amp import GradScaler

# Train a single fold of the k-fold split. `folds`, `train`, `train_img_path`,
# the transforms, `CFG`, `CassvaImgClassifier`, and the helpers in `utils`
# are assumed to be defined earlier in the notebook.
fold_num = 0
for fold, (trn_idx, val_idx) in enumerate(folds):
    if fold != fold_num:  # only train the selected fold
        continue

    print('Training with fold {} started'.format(fold))
    print('Train: {}, Val: {}'.format(len(trn_idx), len(val_idx)))

    train_loader, val_loader = utils.prepare_dataloader(
        train,
        trn_idx,
        val_idx,
        data_root=train_img_path,
        trn_transform=trn_transform,
        val_transform=val_transform,
        bs=CFG['train_bs'],
        n_job=CFG['num_workers'])

    device = torch.device(CFG['device'])
    model = CassvaImgClassifier(CFG['model_arch'],
                                train.label.nunique(),
                                pretrained=True).to(device)

    # Gradient scaler for mixed-precision (AMP) training.
    scaler = GradScaler()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=CFG['lr'],
                                 weight_decay=CFG['weight_decay'])
    # Cosine annealing with a warm restart every T_0 epochs.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=CFG['T_0'],
        T_mult=1,
        eta_min=CFG['min_lr'],
        last_epoch=-1)

    loss_tr = nn.CrossEntropyLoss().to(device)  # training loss
    loss_fn = nn.CrossEntropyLoss().to(device)  # validation loss

    for epoch in range(CFG['epochs']):
        utils.train_one_epoch(epoch,
                              model,
                              loss_tr,
                              optimizer,
                              train_loader,
                              device,
                              scaler,
                              scheduler=scheduler,
                              schd_batch_update=False)

        # Validation runs without gradient tracking.
        with torch.no_grad():
            utils.valid_one_epoch(epoch,
                                  model,
                                  loss_fn,
                                  val_loader,
                                  device)

        # Checkpoint the model for this fold after every epoch.
        torch.save(
            model.state_dict(),
            '../model/{}_fold_{}_{}'.format(CFG['model_arch'], fold, epoch))

    # Free GPU memory before moving to the next fold.
    del model, optimizer, train_loader, val_loader, scaler, scheduler
    torch.cuda.empty_cache()
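The loop above leans on objects built earlier in the notebook, most importantly the CFG hyperparameter dict and the folds index splits. As a point of reference, here is a minimal sketch of what that setup could look like, assuming a stratified 5-fold split over a train DataFrame with a label column; the keys mirror the ones consumed by the loop above, but every value (and the CSV path) is illustrative rather than the original configuration.

import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold

# As loaded earlier in the notebook; the path is illustrative.
train = pd.read_csv('../input/train.csv')

# Illustrative hyperparameters: the keys match those used by the
# training loop, the values are placeholders.
CFG = {
    'model_arch': 'tf_efficientnet_b4_ns',
    'epochs': 10,
    'train_bs': 32,
    'num_workers': 4,
    'lr': 1e-4,
    'min_lr': 1e-6,
    'weight_decay': 1e-6,
    'T_0': 10,
    'device': 'cuda:0',
}

# A stratified k-fold split keeps the class distribution of `label`
# the same in every train/validation pair.
folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=42).split(
    np.arange(train.shape[0]), train.label.values)

Note that with schd_batch_update=False the scheduler is presumably stepped once per epoch inside utils.train_one_epoch, so T_0 is measured in epochs rather than batches.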