PyTorch：修改 ImageFolder 式的图像数据加载器，使其支持划分训练集和验证集

本文介绍如何使用 Custom_split_dataloader 类对 PyTorch 图像数据集进行分层划分：以 test_size=0.2 将训练集和验证集按 4:1（80% / 20%）的比例划分，并提供数据加载和预处理方法。适用于深度学习项目中数据集的组织与管理。
摘要由CSDN通过智能技术生成
"""
author:JIN
"""
from torchvision.datasets.vision import VisionDataset
from sklearn.model_selection import StratifiedShuffleSplit
from PIL import Image
import os
import os.path
import warnings
# Silence every warning category. The filter has no module restriction, so it
# applies process-wide, not just to code in this file.
warnings.simplefilter('ignore')

class Custom_split_dataloader(VisionDataset):
    """ImageFolder-style dataset exposing one side of a stratified train/val split.

    Files are collected from ``root/<class_name>/**/<file>``. Each immediate
    sub-directory of ``root`` is a class, and class indices follow the sorted
    order of the directory names. The full file list is split once with
    scikit-learn's ``StratifiedShuffleSplit``. The instance keeps the training
    part when ``train=True`` and the validation part otherwise.

    Build the training and validation datasets as two separate instances with
    the same ``root``, ``test_size`` and ``random_state``. The split is then
    identical in both, so the two subsets are disjoint. A ``random_state`` of
    None would break this.

    Args:
        root: Directory containing one sub-directory per class.
        test_size: Forwarded to ``StratifiedShuffleSplit``. A float is the
            validation fraction (``0.2`` gives an 80/20 = 4:1 split) and an
            int is an absolute count. ``None`` leaves the choice to
            scikit-learn's default.
        train: Keep the training subset if true, otherwise the validation one.
        transform: Optional callable applied to each loaded image.
        target_transform: Optional callable applied to each class index.
        random_state: Seed for the split. The default is the value that was
            previously hard-coded.
        extensions: Optional iterable of file suffixes, e.g.
            ``('.jpg', '.png')``, matched case-insensitively. ``None`` keeps
            every file found.

    Attributes:
        classes, class_to_idx: Sorted class names and their integer indices.
        samples, imgs: ``(path, class_index)`` pairs of the selected subset
            (the same list object).
        targets: Class indices aligned with ``samples``.
        samples_train, samples_val: Both subsets, regardless of ``train``.

    Raises:
        ValueError: Raised by ``StratifiedShuffleSplit`` when the split is
            impossible, e.g. no files were found or a class has too few files.
    """

    def __init__(self, root, test_size=None, train=True, transform=None,
                 target_transform=None, random_state=724, extensions=None):
        # Let VisionDataset set root/transform/target_transform/transforms so
        # the base-class helpers (e.g. __repr__) see a consistent state.
        super().__init__(root, transform=transform,
                         target_transform=target_transform)
        self.transform = transform
        self.target_transform = target_transform

        self.classes, self.class_to_idx = self._find_classes(self.root)
        all_samples = self._make_dataset(self.root, self.class_to_idx,
                                         extensions)
        all_targets = [target for _, target in all_samples]
        self.loader = self.default_loader

        self.train = train
        self.test_size = test_size
        self.random_state = random_state
        self.patch_dir = root  # kept for backward compatibility; not used here

        splitter = StratifiedShuffleSplit(
            n_splits=1, test_size=test_size, random_state=random_state)
        # n_splits=1, so the generator yields exactly one (train, val) pair.
        self.samples_train_index, self.samples_val_index = next(
            splitter.split(all_samples, all_targets))
        self.samples_train = [all_samples[i] for i in self.samples_train_index]
        self.samples_val = [all_samples[i] for i in self.samples_val_index]

        self.samples = self.samples_train if train else self.samples_val
        self.imgs = self.samples
        # Keep targets aligned with the selected subset. Previously this
        # listed every file, so len(self.targets) != len(self).
        self.targets = [target for _, target in self.samples]

    def pil_loader(self, path):
        """Open ``path`` with PIL and return the image converted to RGB."""
        # Open the path as a file so the handle is closed promptly
        # (https://github.com/python-pillow/Pillow/issues/835). convert()
        # runs inside the `with`, so pixel data is read before the close.
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def accimage_loader(self, path):
        """Open ``path`` with accimage, falling back to PIL on an IOError."""
        import accimage  # optional backend; imported only when selected
        try:
            return accimage.Image(path)
        except IOError:
            # Possibly a decoding problem; fall back to PIL.
            return self.pil_loader(path)

    def default_loader(self, path):
        """Open ``path`` with the backend reported by torchvision's get_image_backend()."""
        from torchvision import get_image_backend
        if get_image_backend() == 'accimage':
            return self.accimage_loader(path)
        return self.pil_loader(path)

    def _find_classes(self, dir):
        """Return the sorted sub-directory names of ``dir`` and a name->index map."""
        # Expand '~' the same way _make_dataset does, and close the scandir
        # iterator deterministically.
        with os.scandir(os.path.expanduser(dir)) as entries:
            classes = sorted(entry.name for entry in entries if entry.is_dir())
        class_to_idx = {name: idx for idx, name in enumerate(classes)}
        return classes, class_to_idx

    def _make_dataset(self, directory, class_to_idx, extensions=None):
        """List ``(path, class_index)`` for every file under each class folder.

        Folders and files are visited in sorted order, so the list (and thus
        the seeded split) is reproducible for the same directory contents.
        If ``extensions`` is given, only files whose lower-cased name ends
        with one of them are kept.
        """
        if isinstance(extensions, str):
            extensions = (extensions,)
        if extensions is not None:
            extensions = tuple(ext.lower() for ext in extensions)

        instances = []
        directory = os.path.expanduser(directory)
        for target_class in sorted(class_to_idx):
            class_index = class_to_idx[target_class]
            target_dir = os.path.join(directory, target_class)
            if not os.path.isdir(target_dir):
                continue
            for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
                for fname in sorted(fnames):
                    if (extensions is not None
                            and not fname.lower().endswith(extensions)):
                        continue
                    instances.append((os.path.join(root, fname), class_index))
        return instances

    def __getitem__(self, index):
        """Return ``(image, target)`` for the ``index``-th sample of the subset."""
        path, target = self.samples[index]
        sample = self.loader(path)
        # transform defaults to None; calling it unconditionally raised TypeError.
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        """Return the number of samples in the selected subset."""
        return len(self.samples)
    










使用方法：将训练集按 4:1 的比例（test_size=0.2，即 80% 训练 / 20% 验证）划分为训练集和验证集。如需 5:1 的比例，请使用 test_size=1/6。

# Split dataset/train/ into training and validation subsets. test_size=0.2
# gives 80% / 20% = 4:1. Both datasets must use the same root and test_size;
# the split seed is fixed inside Custom_split_dataloader, so the two subsets
# come from the same split and do not overlap.
# NOTE(review): `data_transform` and `args` (batchSize, threads) are expected
# to be defined by the surrounding training script; they are not defined here.
# NOTE(review): the same data_transform is used for train and val. If it
# contains random augmentation, val should probably get a deterministic one.
import torch.utils.data

VAL_FRACTION = 0.2

print("====>load traindatset ")
train_data_root = 'dataset/train/'
train_dataset = Custom_split_dataloader(train_data_root, transform=data_transform,
                                        test_size=VAL_FRACTION, train=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batchSize,
                                           shuffle=True, drop_last=True,
                                           num_workers=args.threads)

print("====>load valdatset ")
val_dataset = Custom_split_dataloader(train_data_root, transform=data_transform,
                                      test_size=VAL_FRACTION, train=False)
# Evaluate every validation sample in a fixed order. drop_last=True would
# silently skip up to batch_size-1 validation images, and shuffling is
# unnecessary for evaluation.
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batchSize,
                                         shuffle=False, drop_last=False,
                                         num_workers=args.threads)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值