pytorch中把ImageFolder划分做K折交叉验证的代码框架

在图像分类任务中,有时候需要把图像数据分为K折做交叉验证,以评估模型的性能。但是在pytorch中并没有相应的代码,因此自己写了一个框架,以便调用

import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import ToTensor
from PIL import Image
from sklearn.model_selection import train_test_split, StratifiedKFold, KFold
import numpy as np


class ImageFolderSplitKFold:
    '''
    path: 保存图像的文件
    K: 划分的折数
    '''

    def __init__(self, path, K=5):
        self.path = path
        self.K = K
        self.class2num = {}
        self.num2class = {}
        self.class_nums = {}
        self.data_x_path = []
        self.data_y_label = []

        for root, dirs, files in os.walk(path):
            if len(files) == 0 and len(dirs) > 1:
                for i, dir1 in enumerate(dirs):
                    self.num2class[i] = dir1
                    self.class2num[dir1] = i
            elif len(files) > 1 and len(dirs) == 0:
                category = ""
                for key in self.class2num.keys():
                    if key in root:
                        category = key
                        break
                label = self.class2num[category]
                self.class_nums[label] = 0
                for file1 in files:
                    self.data_x_path.append(os.path.join(root, file1))
                    self.data_y_label.append(label)
                    self.class_nums[label] += 1
            else:
                raise RuntimeError("please check the folder structure!")

        self.StratifiedKFoldData = {}
        skf = StratifiedKFold(n_splits=self.K)
        skf.get_n_splits(self.data_x_path, self.data_y_label)
        print(skf)
        i = 1
        for train_index, test_index in skf.split(self.data_x_path, self.data_y_label):
            X_train, X_test = np.array(self.data_x_path)[train_index], np.array(self.data_x_path)[test_index]
            y_train, y_test = np.array(self.data_y_label)[train_index], np.array(self.data_y_label)[test_index]
            name = f'K{i}'
            self.StratifiedKFoldData[name] = ((X_train, y_train), (X_test, y_test))
            i += 1

    def getKFoldData(self):
        '''
        返回一个字典,字典里共包含K个键值对。
        keys: K1, K2, K3, ....
        values: ((x_train,y_train),(x_test,y_test))  其中的 x_train 包含K-1份,x_test包含1份
        examples:  假如K=5, 用1,2,3,4,5代表5折,
                  则:    x_train(y_train)    x_test(y_test)
                         1,2,3,4                    5
                         1,2,3,5                    4
                         1,2,4,5                    3
                         1,3,4,5                    2
                         2,3,4,5                    1
        '''
        return self.StratifiedKFoldData


class DatasetFromFilename(Dataset):
    # x: a list of image file full path
    # y: a list of image label
    def __init__(self, x, y, transforms=None):
        super(DatasetFromFilename, self).__init__()
        self.x = x
        self.y = y
        if transforms == None:
            self.transforms = ToTensor()
        else:
            self.transforms = transforms

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        img = Image.open(self.x[idx])
        img = img.convert("RGB")
        return self.transforms(img), torch.tensor(self.y[idx])


def main():

    # device = ...
    # criterion = ...

    ImageSplit = ImageFolderSplitKFold('.....data_dir', 5)
    DATA = ImageSplit.getKFoldData()
    idx_to_class = ImageSplit.num2class
    my_transforms = transforms.Compose([
        transforms.RandomCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.RandomHorizontalFlip()
    ])

    for k, key in enumerate(DATA.keys()):  # K折交叉验证

        # model = ...
        # model.to(device)
        # optimizer = ...

        (x_train, y_train), (x_test, y_test) = DATA[key]
        training_dataset = DatasetFromFilename(x_train, y_train, transforms=my_transforms)
        test_dataset = DatasetFromFilename(x_test, y_test, transforms=my_transforms)
        train_loader = DataLoader(training_dataset, batch_size=16, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=16, shuffle=True)

        for epoch in range(20):
            pass
            # train for one epoch, printing every 10 iterations

            # 训练
            # train_one_epoch(model, criterion, optimizer, train_loader, idx_to_class, device, epoch)

            # 测试
            # evaluate(model, criterion, test_loader, idx_to_class, device)




if __name__ == '__main__':
    main()





评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值