import os
import numpy as np
from PIL import Image
from sklearn.model_selection import KFold
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
class ImageDataset(Dataset):
    """Image-classification dataset read from one subfolder per class.

    Images are discovered under ``image_dir/<folder>/`` for every folder in
    ``folder_to_label``. By default, folder ``'1'`` maps to label 0 and folder
    ``'2'`` maps to label 1. Class folders that do not exist are skipped
    silently.

    Args:
        image_dir: Root directory containing the per-class subfolders.
        transform: Optional callable applied to each RGB PIL image in
            ``__getitem__`` (presumably a torchvision transform pipeline).
        folder_to_label: Optional mapping of subfolder name to integer label.
            Defaults to ``{'1': 0, '2': 1}``.
        extensions: File-name suffixes treated as images. Matching is
            case-insensitive.
    """

    DEFAULT_FOLDER_TO_LABEL = {'1': 0, '2': 1}

    def __init__(self, image_dir, transform=None, folder_to_label=None,
                 extensions=('.png', '.jpg', '.jpeg')):
        self.image_dir = image_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        if folder_to_label is None:
            folder_to_label = self.DEFAULT_FOLDER_TO_LABEL
        # str.endswith needs a tuple. Lower-case once so 'IMG.JPG' matches '.jpg'.
        suffixes = tuple(ext.lower() for ext in extensions)
        for label_folder, label in folder_to_label.items():
            class_dir = os.path.join(image_dir, label_folder)
            if not os.path.isdir(class_dir):
                continue
            # os.listdir order is arbitrary. Sort so sample indices (and any
            # index-based cross-validation split) are reproducible.
            for file_name in sorted(os.listdir(class_dir)):
                file_path = os.path.join(class_dir, file_name)
                # isfile() rejects sub-directories that happen to be named
                # like images, which would otherwise fail in __getitem__.
                if (file_name.lower().endswith(suffixes)
                        and os.path.isfile(file_path)):
                    self.image_paths.append(file_path)
                    self.labels.append(label)

    def __len__(self):
        """Return the number of discovered images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` as an RGB image and return ``(image, label)``.

        The image is passed through ``self.transform`` when one was given.
        """
        # The context manager closes the underlying file handle. convert()
        # loads the pixel data into a new image, so the result stays usable.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]