"""
author:JIN
"""
from torchvision.datasets.vision import VisionDataset
from sklearn.model_selection import StratifiedShuffleSplit
from PIL import Image
import os
import os.path
import warnings
# NOTE(review): this silences EVERY warning process-wide for any code that
# imports this module (deprecations, sklearn/PIL warnings, ...). Consider
# narrowing it to the specific warning categories that are actually noisy.
warnings.filterwarnings('ignore')
class Custom_split_dataloader(VisionDataset):
    """Image-folder dataset with a built-in stratified train/validation split.

    Expects ``root`` laid out like torchvision's ``ImageFolder``
    (``root/<class_name>/<image files>``). All samples are indexed, then
    partitioned into train and validation subsets with
    ``StratifiedShuffleSplit`` so class proportions are preserved in both.
    Because the split uses a fixed ``random_state``, two instances built
    from the same ``root`` and ``test_size`` (one with ``train=True``, one
    with ``train=False``) see disjoint, complementary subsets.

    Args:
        root (str): Root directory with one sub-directory per class.
        test_size (float or int, optional): Fraction (or absolute count) of
            samples assigned to the validation subset; forwarded to
            ``StratifiedShuffleSplit`` (its default applies when ``None``).
        train (bool): If ``True``, serve the training subset; otherwise the
            validation subset.
        transform (callable, optional): Applied to each loaded PIL image.
    """

    def __init__(self, root, test_size=None, train=True, transform=None):
        # Initialize the VisionDataset base class (the original code skipped
        # this, leaving base-class state such as ``transforms`` unset).
        super().__init__(root, transform=transform)
        self.root = root
        classes, class_to_idx = self._find_classes(self.root)
        samples = self._make_dataset(self.root, class_to_idx)
        self.loader = self.default_loader
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]
        self.transform = transform
        self.train = train
        self.test_size = test_size
        # n_splits=1 -> the generator yields exactly one (train_idx, val_idx)
        # pair. The fixed random_state makes the split reproducible across
        # the train=True / train=False instances.
        [self.samples_train_index, self.samples_val_index] = list(
            StratifiedShuffleSplit(
                n_splits=1, test_size=self.test_size, random_state=724
            ).split(self.samples, self.targets)
        )[0]
        self.samples_train = [self.samples[i] for i in self.samples_train_index]
        self.samples_val = [self.samples[i] for i in self.samples_val_index]
        self.patch_dir = root
        if self.train:
            self.imgs = self.samples_train
        else:
            self.imgs = self.samples_val
        # From here on, ``samples`` refers only to the selected subset.
        self.samples = self.imgs

    def pil_loader(self, path):
        """Load an image with PIL and force 3-channel RGB.

        Opened via a file object to avoid a ResourceWarning
        (https://github.com/python-pillow/Pillow/issues/835).
        """
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def accimage_loader(self, path):
        """Load an image with accimage, falling back to PIL on decode errors."""
        import accimage
        try:
            return accimage.Image(path)
        except IOError:
            # Potentially a decoding problem; fall back to PIL.Image.
            return self.pil_loader(path)

    def default_loader(self, path):
        """Dispatch to the loader matching torchvision's configured backend."""
        from torchvision import get_image_backend
        if get_image_backend() == 'accimage':
            return self.accimage_loader(path)
        else:
            return self.pil_loader(path)

    def _find_classes(self, dir):
        """Return (sorted class names, {class_name: index}) for ``dir``'s sub-dirs."""
        classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        classes.sort()
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def _make_dataset(self, directory, class_to_idx):
        """Walk ``directory`` and return a list of (path, class_index) samples.

        NOTE(review): no file-extension filtering is done, so any
        non-image file inside a class folder is also indexed — verify the
        dataset folders contain only images.
        """
        instances = []
        directory = os.path.expanduser(directory)
        for target_class in sorted(class_to_idx.keys()):
            class_index = class_to_idx[target_class]
            target_dir = os.path.join(directory, target_class)
            if not os.path.isdir(target_dir):
                continue
            for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
                for fname in sorted(fnames):
                    path = os.path.join(root, fname)
                    item = path, class_index
                    instances.append(item)
        return instances

    def __getitem__(self, index):
        """Return ``(sample, target)`` for the subset sample at ``index``."""
        path, target = self.samples[index]
        sample = self.loader(path)
        # Guard against transform=None: the original called
        # self.transform(sample) unconditionally and raised TypeError.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target

    def __len__(self):
        """Number of samples in the selected (train or val) subset."""
        return len(self.samples)
# ---------------------------------------------------------------------------
# Usage example: split one training folder into train (80%) and validation
# (20%) subsets that share the same stratified split.
# (Kept as a comment: the original trailing snippet was bare prose plus code
# referencing names undefined in this module — `torch`, `data_transform`,
# `args` — so it could never execute and made the file unimportable.)
#
# print("====>load traindatset ")
# train_data_root = 'dataset/train/'
# train_dataset = Custom_split_dataloader(
#     train_data_root, transform=data_transform, test_size=0.2, train=True)
# train_loader = torch.utils.data.DataLoader(
#     train_dataset, batch_size=args.batchSize,
#     shuffle=True, drop_last=True, num_workers=args.threads)
#
# print("====>load valdatset ")
# val_dataset = Custom_split_dataloader(
#     train_data_root, transform=data_transform, test_size=0.2, train=False)
# val_loader = torch.utils.data.DataLoader(
#     val_dataset, batch_size=args.batchSize,
#     shuffle=True, drop_last=True, num_workers=args.threads)
# ---------------------------------------------------------------------------