K-fold training

import os

# Train the model on a single, selected fold of the k-fold split.
# `folds`, `train`, `utils`, `CFG`, `CassvaImgClassifier`, `train_img_path`,
# `trn_transform` and `val_transform` are defined earlier in the file.
fold_num = 0  # index of the fold to train; change to train a different fold

for fold, (trn_idx, val_idx) in enumerate(folds):
    if fold != fold_num:
        continue  # only the selected fold is trained

    print('Training with {} started'.format(fold))
    print('Train : {}, Val : {}'.format(len(trn_idx), len(val_idx)))

    train_loader, val_loader = utils.prepare_dataloader(
        train,
        trn_idx,
        val_idx,
        data_root=train_img_path,
        trn_transform=trn_transform,
        val_transform=val_transform,
        bs=CFG['train_bs'],
        n_job=CFG['num_workers'])

    device = torch.device(CFG['device'])

    # Number of classes is taken from the label column of the training frame.
    model = CassvaImgClassifier(CFG['model_arch'],
                                train.label.nunique(),
                                pretrained=True).to(device)
    scaler = GradScaler()  # AMP gradient scaler for mixed-precision training
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=CFG['lr'],
                                 weight_decay=CFG['weight_decay'])

    # Cosine-annealing schedule that warm-restarts every T_0 epochs.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=CFG['T_0'],
        T_mult=1,
        eta_min=CFG['min_lr'],
        last_epoch=-1)

    # CrossEntropyLoss is stateless, so one instance serves both the
    # training and the validation passes (the original built two identical
    # copies as loss_tr / loss_fn).
    criterion = nn.CrossEntropyLoss().to(device)

    # Make sure the checkpoint directory exists before the first save;
    # torch.save raises if the parent directory is missing.
    os.makedirs('../model', exist_ok=True)

    for epoch in range(CFG['epochs']):
        utils.train_one_epoch(epoch,
                              model,
                              criterion,
                              optimizer,
                              train_loader,
                              device,
                              scaler,
                              scheduler=scheduler,
                              schd_batch_update=False)

        # Validation needs no gradients.
        with torch.no_grad():
            utils.valid_one_epoch(epoch,
                                  model,
                                  criterion,
                                  val_loader,
                                  device)

        # Checkpoint after every epoch: <arch>_fold_<fold>_<epoch>.
        torch.save(
            model.state_dict(),
            '../model/{}_fold_{}_{}'.format(CFG['model_arch'], fold, epoch))

    # Release GPU memory held by this fold's objects.
    del model, optimizer, train_loader, val_loader, scaler, scheduler
    torch.cuda.empty_cache()

    # Only one fold is trained per run; no need to scan the remaining folds.
    break

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值