在图像分类任务中，有时需要将图像数据划分为 K 折进行交叉验证，以评估模型的性能。但 PyTorch 本身并没有提供相应的代码，因此我自己写了一个框架，以便调用。
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import ToTensor
from PIL import Image
from sklearn.model_selection import train_test_split, StratifiedKFold, KFold
import numpy as np
class ImageFolderSplitKFold:
    """Index an image-classification folder and split it into K stratified folds.

    Expected layout (one sub-directory per class, image files directly inside)::

        path/
            class_a/  img1.jpg  img2.jpg ...
            class_b/  img3.jpg ...

    Class folders and the files inside them are sorted by name before labels
    are assigned, so the class -> label mapping and every fold are
    reproducible (``os.walk`` / ``os.listdir`` order is arbitrary).

    Args:
        path: root directory containing one sub-directory per class.
        K: number of folds, passed to ``StratifiedKFold`` as ``n_splits``.
        shuffle: forwarded to ``StratifiedKFold``; ``False`` (default) keeps
            the original, non-shuffled splitting.
        random_state: forwarded to ``StratifiedKFold``; scikit-learn only
            accepts a non-``None`` value when ``shuffle`` is ``True``.

    Attributes:
        class2num / num2class: class-name <-> integer-label mappings.
        class_nums: number of files found for each label.
        data_x_path / data_y_label: flat, aligned lists of file paths and labels.
        StratifiedKFoldData: the fold dictionary returned by ``getKFoldData``.

    Raises:
        RuntimeError: if the folder does not match the expected layout
            (files in the root, fewer than two class folders, an empty class
            folder, or a sub-directory inside a class folder).
    """

    def __init__(self, path, K=5, shuffle=False, random_state=None):
        self.path = path
        self.K = K
        self.class2num = {}
        self.num2class = {}
        self.class_nums = {}
        self.data_x_path = []
        self.data_y_label = []
        self._scan_folder()

        skf = StratifiedKFold(n_splits=self.K, shuffle=shuffle, random_state=random_state)
        print(skf)
        # Convert once rather than rebuilding both arrays for every fold.
        x_all = np.array(self.data_x_path)
        y_all = np.array(self.data_y_label)
        self.StratifiedKFoldData = {}
        for i, (train_index, test_index) in enumerate(skf.split(x_all, y_all), start=1):
            self.StratifiedKFoldData[f'K{i}'] = (
                (x_all[train_index], y_all[train_index]),
                (x_all[test_index], y_all[test_index]),
            )

    def _scan_folder(self):
        """Fill the label mappings and the flat (path, label) lists from ``self.path``."""
        entries = sorted(os.listdir(self.path))
        class_names = [e for e in entries if os.path.isdir(os.path.join(self.path, e))]
        # The root may contain only class folders, and at least two of them.
        if len(class_names) != len(entries) or len(class_names) < 2:
            raise RuntimeError("please check the folder structure!")
        for label, class_name in enumerate(class_names):
            # The label comes from the folder's own name, not a substring match on the path.
            class_dir = os.path.join(self.path, class_name)
            children = sorted(os.listdir(class_dir))
            files = [c for c in children if not os.path.isdir(os.path.join(class_dir, c))]
            # Each class folder needs at least one file and no nested folders.
            if not files or len(files) != len(children):
                raise RuntimeError("please check the folder structure!")
            self.num2class[label] = class_name
            self.class2num[class_name] = label
            self.class_nums[label] = len(files)
            self.data_x_path.extend(os.path.join(class_dir, f) for f in files)
            self.data_y_label.extend([label] * len(files))

    def getKFoldData(self):
        """Return the K stratified train/test splits.

        Returns:
            dict with keys ``'K1'`` ... ``'K{K}'``. Each value is
            ``((x_train, y_train), (x_test, y_test))``, where x holds numpy
            arrays of file paths and y holds integer labels. Under ``K{i}``
            the i-th fold is the test set and the other K-1 folds form the
            training set. For K=5:

                key   train folds   test fold
                K1    2,3,4,5       1
                K2    1,3,4,5       2
                K3    1,2,4,5       3
                K4    1,2,3,5       4
                K5    1,2,3,4       5
        """
        return self.StratifiedKFoldData
class DatasetFromFilename(Dataset):
    """Map-style dataset that loads images lazily from a list of file paths.

    Args:
        x: sequence of image file paths.
        y: sequence of labels aligned with ``x``; each is wrapped in
            ``torch.tensor`` when an item is fetched.
        transforms: callable applied to each RGB ``PIL.Image``. If ``None``,
            ``ToTensor()`` is used.

    Raises:
        ValueError: if ``x`` and ``y`` have different lengths.
    """

    def __init__(self, x, y, transforms=None):
        super().__init__()
        # Catch misaligned inputs up front instead of failing (or silently
        # dropping labels) midway through an epoch.
        if len(x) != len(y):
            raise ValueError(f"x and y must have the same length, got {len(x)} and {len(y)}")
        self.x = x
        self.y = y
        # Compare by identity: `== None` can misbehave for objects that override __eq__.
        self.transforms = ToTensor() if transforms is None else transforms

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label_tensor)`` for sample ``idx``."""
        # The context manager closes the file handle. convert() produces a new
        # in-memory image, so it stays valid after the file is closed.
        with Image.open(self.x[idx]) as img:
            rgb = img.convert("RGB")
        return self.transforms(rgb), torch.tensor(self.y[idx])
def main(data_dir='.....data_dir', n_splits=5, epochs=20, batch_size=16):
    """Demo: build K stratified folds and the DataLoaders for each fold.

    Training images get random crop/flip augmentation. Test images get a
    deterministic resize + centre crop, so evaluation on each fold is
    reproducible.

    Args:
        data_dir: root folder laid out as ``data_dir/<class_name>/<images>``.
        n_splits: number of folds K.
        epochs: number of iterations of the (placeholder) training loop.
        batch_size: mini-batch size for both loaders.
    """
    image_split = ImageFolderSplitKFold(data_dir, n_splits)
    folds = image_split.getKFoldData()
    idx_to_class = image_split.num2class  # label -> class-name mapping, kept for user code

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # Resize the shorter side to 256 first so a 224x224 crop never exceeds the
    # image size (RandomCrop raises on smaller inputs unless padding is enabled).
    train_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # No randomness at evaluation time.
    test_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    for fold_name, ((x_train, y_train), (x_test, y_test)) in folds.items():
        training_dataset = DatasetFromFilename(x_train, y_train, transforms=train_transforms)
        test_dataset = DatasetFromFilename(x_test, y_test, transforms=test_transforms)
        train_loader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True)
        # Evaluation results do not depend on order, so the test set is not shuffled.
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
        for epoch in range(epochs):
            pass  # TODO: train on train_loader and evaluate on test_loader for this fold
if __name__ == "__main__":
    # Run the K-fold demo only when executed as a script, not on import.
    main()