pytorch k 折

### Load dataset

trainset = datasets.ImageFolder(train_dir, transform=transform)
testset = datasets.ImageFolder(test_dir, transform=transform)

import torch
from sklearn.model_selection import KFold

data_induce = np.arange(0, len(trainset))  # 将“训练集”分为训练集和验证集
kf = KFold(n_splits=5)  # 分成 5 份

for k, (train_index, val_index) in enumerate(kf.split(data_induce)):

    print('{} - FOLD'.format(k))

    train_subset = torch.utils.data.dataset.Subset(trainset, train_index)
    val_subset = torch.utils.data.dataset.Subset(trainset, val_index)
    trainloader = DataLoader(dataset=train_subset, batch_size=bs, pin_memory=True)
    valloader = DataLoader(dataset=val_subset, batch_size=bs, pin_memory=True)

    ### Build model
    criterion = nn.CrossEntropyLoss()
    model = torchvision.models.resnet18(pretrained=False)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 2)  ## 2 classes
    model.to(DEVICE)

    ### Design optimizer
    optimizer = optim.Adam(model.parameters(), lr=modellr)

    for epoch in range(1, EPOCHS + 1):
        adjust_learning_rate(optimizer, epoch)
        train(model, DEVICE, trainloader, optimizer, epoch, k)
        val(model, DEVICE, valloader, k)
    torch.save(model, 'model_{}_fold.pth'.format(k))

在原创基础上加工,原创有点找不到了,仅作记录

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值