PyTorch 中将 ImageFolder 格式的图像数据划分为 K 折做交叉验证的代码框架

在图像分类任务中,有时需要把图像数据划分为 K 折做交叉验证,以评估模型的性能。但是 PyTorch 中并没有现成的相应代码,因此自己写了一个框架,以便调用。

import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import ToTensor
from PIL import Image
from sklearn.model_selection import train_test_split, StratifiedKFold, KFold
import numpy as np


class ImageFolderSplitKFold:
    """Split an image folder into K stratified folds for cross-validation.

    The folder must use a two-level layout: ``path`` contains one
    sub-directory per class, and each sub-directory directly contains the
    image files of that class.

    Args:
        path: Root folder holding one sub-directory per class.
        K: Number of folds.

    Attributes:
        class2num: Maps class (sub-directory) name to integer label.
        num2class: Maps integer label back to class name.
        class_nums: Maps integer label to the number of files in that class.
        data_x_path: List of all image file paths.
        data_y_label: List of integer labels, aligned with ``data_x_path``.
        StratifiedKFoldData: The folds, see :meth:`getKFoldData`.

    Raises:
        RuntimeError: If ``path`` does not follow the expected layout.
    """

    def __init__(self, path, K=5):
        self.path = path
        self.K = K
        (self.class2num,
         self.num2class,
         self.class_nums,
         self.data_x_path,
         self.data_y_label) = self._scan_folder(path)
        self.StratifiedKFoldData = self._build_folds()

    @staticmethod
    def _scan_folder(path):
        """Collect file paths and labels from a ``path/<class>/<file>`` tree.

        Returns:
            Tuple ``(class2num, num2class, class_nums, data_x_path,
            data_y_label)``.

        Raises:
            RuntimeError: If ``path`` is not a directory, is empty, contains
                anything other than class directories, or a class directory
                is empty or contains nested directories.
        """
        if not os.path.isdir(path):
            raise RuntimeError("please check the folder structure!")

        # Sorted names give the same label ids on every machine; the raw
        # directory listing order depends on the filesystem.
        class_names = sorted(os.listdir(path))
        if not class_names or not all(
                os.path.isdir(os.path.join(path, name)) for name in class_names):
            raise RuntimeError("please check the folder structure!")

        class2num = {}
        num2class = {}
        class_nums = {}
        data_x_path = []
        data_y_label = []

        for label, class_name in enumerate(class_names):
            class2num[class_name] = label
            num2class[label] = class_name

            # The label comes from the directory being listed, never from a
            # substring search in the path: "cat" must not capture "cats",
            # nor a ".../dog_vs_cat/dog" folder.
            class_dir = os.path.join(path, class_name)
            file_names = sorted(os.listdir(class_dir))
            if not file_names or any(
                    os.path.isdir(os.path.join(class_dir, name)) for name in file_names):
                raise RuntimeError("please check the folder structure!")

            data_x_path.extend(os.path.join(class_dir, name) for name in file_names)
            data_y_label.extend([label] * len(file_names))
            class_nums[label] = len(file_names)

        return class2num, num2class, class_nums, data_x_path, data_y_label

    def _build_folds(self):
        """Build the ``{'K1': ..., 'K2': ...}`` dictionary of stratified folds."""
        # Convert once; index arrays from the splitter need NumPy indexing.
        paths = np.array(self.data_x_path)
        labels = np.array(self.data_y_label)

        splitter = StratifiedKFold(n_splits=self.K)
        folds = {}
        for i, (train_index, test_index) in enumerate(splitter.split(paths, labels), start=1):
            folds[f'K{i}'] = (
                (paths[train_index], labels[train_index]),
                (paths[test_index], labels[test_index]),
            )
        return folds

    def getKFoldData(self):
        """Return the K folds as a dictionary.

        Returns:
            A dict with K entries. Keys are ``'K1'``, ``'K2'``, ...; each
            value is ``((x_train, y_train), (x_test, y_test))`` where the
            train part holds K-1 folds and the test part holds the remaining
            one. ``x_*`` are NumPy arrays of file paths and ``y_*`` are NumPy
            arrays of integer labels.

            For example, with K=5 and the folds numbered 1..5, the five
            entries cover these combinations (one per key):

                x_train(y_train)    x_test(y_test)
                1,2,3,4             5
                1,2,3,5             4
                1,2,4,5             3
                1,3,4,5             2
                2,3,4,5             1
        """
        return self.StratifiedKFoldData


class DatasetFromFilename(Dataset):
    """Dataset that loads images lazily from a list of file paths.

    Args:
        x: Sequence of full image file paths.
        y: Sequence of labels, aligned with ``x``.
        transforms: Callable applied to the RGB PIL image of each sample.
            Defaults to ``ToTensor()`` when ``None``.
    """

    def __init__(self, x, y, transforms=None):
        super().__init__()
        self.x = x
        self.y = y
        # Identity check: ``== None`` would call the transform's own __eq__,
        # which may be overloaded.
        self.transforms = ToTensor() if transforms is None else transforms

    def __len__(self):
        """Return the number of samples."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label_tensor)`` for sample ``idx``."""
        # convert() returns a new, fully loaded image, so the source file can
        # be closed right away instead of waiting for garbage collection.
        with Image.open(self.x[idx]) as source:
            img = source.convert("RGB")
        return self.transforms(img), torch.tensor(self.y[idx])


def main():
    """Template for running K-fold cross-validation on an image folder.

    Builds the folds, then for every fold creates a training loader (with
    random augmentation) and a test loader (deterministic preprocessing).
    Model, loss, optimizer and the train/evaluate calls are left as
    commented placeholders to be filled in.
    """

    # device = ...
    # criterion = ...

    ImageSplit = ImageFolderSplitKFold('.....data_dir', 5)
    DATA = ImageSplit.getKFoldData()
    idx_to_class = ImageSplit.num2class

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # Random augmentation is for training only. The flip runs on the PIL
    # image, before ToTensor.
    train_transforms = transforms.Compose([
        transforms.RandomCrop((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # Evaluation uses a fixed crop so that the score of a fold does not
    # change from one run to the next.
    test_transforms = transforms.Compose([
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    for key in DATA:  # K-fold cross-validation

        # Create the model and optimizer here so every fold starts fresh.
        # model = ...
        # model.to(device)
        # optimizer = ...

        (x_train, y_train), (x_test, y_test) = DATA[key]
        training_dataset = DatasetFromFilename(x_train, y_train, transforms=train_transforms)
        test_dataset = DatasetFromFilename(x_test, y_test, transforms=test_transforms)
        train_loader = DataLoader(training_dataset, batch_size=16, shuffle=True)
        # Order does not matter for evaluation; keep it deterministic.
        test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

        for epoch in range(20):
            pass
            # train for one epoch, printing every 10 iterations

            # training
            # train_one_epoch(model, criterion, optimizer, train_loader, idx_to_class, device, epoch)

            # evaluation
            # evaluate(model, criterion, test_loader, idx_to_class, device)



# Run the cross-validation template only when this file is executed as a
# script, not when it is imported.
if __name__ == '__main__':
    main()





# Sure! Below is an example of k-fold cross-validation with PyTorch.
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from sklearn.model_selection import KFold

# Set the random seed
torch.manual_seed(0)

# Load the dataset
dataset = CIFAR10(root='data/', train=True, download=True, transform=ToTensor())

# Loss function (it holds no trainable state, so all folds can share it)
criterion = torch.nn.CrossEntropyLoss()

# Number of folds
k = 5

# Initialize the k-fold splitter. random_state pins the shuffle, because
# torch.manual_seed does not seed scikit-learn.
kfold = KFold(n_splits=k, shuffle=True, random_state=0)

# Run the k-fold cross-validation
for fold, (train_ids, val_ids) in enumerate(kfold.split(dataset)):
    print(f'Fold {fold + 1}/{k}')
    print('-' * 10)

    # Create a fresh model and optimizer for every fold. Reusing one model
    # would validate later folds on samples it was already trained on.
    model = YourModel()  # placeholder: replace with your own nn.Module
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Data loaders restricted to this fold's training / validation indices
    train_loader = DataLoader(dataset, batch_size=32, sampler=SubsetRandomSampler(train_ids))
    val_loader = DataLoader(dataset, batch_size=32, sampler=SubsetRandomSampler(val_ids))

    # Train the model
    for epoch in range(10):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        epoch_loss = running_loss / len(train_loader)
        print(f'Train Loss: {epoch_loss}')

    # Evaluate the model on the validation fold
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = model(inputs)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    print(f'Validation Accuracy: {accuracy}')
    print()

# Note: the code above is only an example; adapt it to your own dataset and
# model as needed.
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值