Pytorch 自定义数据加载器

在前面,我们使用Lenet训练的都是使用默认数据加载器加载特定的数据,本章节我们分析下怎么使用自定义的data.Dataset加载数据

口罩数据集

数据分为两类,mask和no_mask,数据集全部来自于网络

  • mask数据

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-fNt7ei6M-1610790347001)(https://i.loli.net/2021/01/16/h1TvjaeJ6QzOCR5.png)]

  • no_mask数据

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-GFz7JWpW-1610790347003)(https://i.loli.net/2021/01/16/KlawhSNpDURXFEW.png)]

数据目录结构为:


├── mask
│   ├── test
│   │   ├── mask_0
│   │   └── no_mask_1
│   └── train
│       ├── mask_0
│       └── no_mask_1

数据集下载
链接:https://pan.baidu.com/s/10UcSznNbaUJn8EsVeQc2yg
提取码:2kjb

自定义数据集加载器

主要思路是,通过分割train,和test下各个目录下的图片目录,解析出分类名称和id,例如

mask_0       --->         class:mask     class_id:0
no_mask_1    --->         class:no_mask  class_id:1

关键函数如下

find_classes

def find_classes_with_id(dir:str) -> Tuple[List[str], Dict[str, int]]:
    classes = []
    class_ids = {}
    class_dirs = [d.name for d in os.scandir(dir) if d.is_dir]
    for class_dir in class_dirs:
        split_list = class_dir.split('_')
        if len(split_list)==1:
            msg = "{} form is not right, it should be [classname_id]!!\n".format(class_dir)
            raise RuntimeError(msg)
        class_id = split_list[len(split_list)-1]
        if not class_id.isdigit():
            msg = "{} is not end with '_digit' !\n".format(class_dir)
            raise RuntimeError(msg)
        finaly_split_s = '_'+class_id
        class_name = class_dir.split(finaly_split_s)[0]
        classes.append(class_name)
        class_ids[class_name] = class_id

    return classes, class_ids

返回的结果如下

['no_mask', 'mask']
{'no_mask': '1', 'mask': '0'}

make_dataset

def make_dataset(
    directory: str,
    class_to_idx: Dict[str, int],) -> List[Tuple[str, int]]:

    instances = []#struct

    if not os.path.isdir(directory):
        raise ValueError("Image not dir!!!")

    image_count = 0
    for target_class in sorted(class_to_idx.keys()):
            class_index = class_to_idx[target_class]
            target_class_dir  = target_class+'_'+str(class_index)
            target_dir = os.path.join(directory, target_class_dir)
            if not os.path.isdir(target_dir):
                continue
            for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
                for fname in sorted(fnames):
                    path = os.path.join(root, fname)
                    if is_valid_file(path):
                        item = path, class_index
                        instances.append(item)

    return instances

解析每一个分类目录下图片,返回每个图片的路径和分类id结果,每个item的格式如下

('./datas/mask/train/mask_0/08021120510000090.jpg', 0)


定义数据加载器 CommonData

然后,集成继承data.Dataset实现CommonData加载器

class CommonData(data.Dataset):

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.mean = (0.479, 0.385, 0.352)
        self.std = (0.194, 0.171, 0.165)
        #获取类别和类别id
        classes, class_to_idx = find_classes_with_id(root_dir)
        samples = make_dataset(root_dir, class_to_idx)

        if len(samples) == 0:
            msg = "Found 0 files in subfolders of: {}\n".format(root_dir)
            raise RuntimeError(msg)
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]
        print('------SmokeData[%s]',root_dir)
        print('classes[%s]'% self.classes)
        print('class_to_idx[%s]'%self.class_to_idx)
        self.count = 0
        #print('targets[%s]'%self.targets)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, target = self.samples[idx]
        sample = default_loader(path)
        self.count = self.count+1
        #sample.save("./"+str(self.count)+'.jpg')
        if self.transform is not None:
            sample = self.transform(sample)

        return sample, target

主要实现__len__和__getitem__方法,分别返回数据集的长度,和遍历获取每个数据的图片和分类id。

使用数据加载器

def getClassfierDataset(train_dir, test_dir, dataresize=64):
        train_transforms = transforms.Compose([
                                               transforms.RandomRotation(20),
                                               transforms.Resize((dataresize,dataresize)),
                                               transforms.RandomHorizontalFlip(0.5), 
                                               #transforms.ColorJitter(brightness=[0.8,1.3], contrast=[0.8,1.3], saturation=[0.8,1.3], hue=0.2),
                                               transforms.ToTensor(), 
                                               transforms.Normalize((0.479, 0.385, 0.352),
                                                                    (0.194, 0.171, 0.165))])

        test_transforms = transforms.Compose([
                                            transforms.Resize((dataresize,dataresize)),
                                            transforms.ToTensor(),
                                            transforms.Normalize((0.479, 0.385, 0.352),
                                                            (0.194, 0.171, 0.165))])
        tain_smoke_data = CommonData(train_dir, transform = train_transforms)
        test_smoke_data = CommonData(test_dir, transform = test_transforms)
        # 使用预处理格式加载图像
        #train_data = datasets.ImageFolder(train_dir,transform = train_transforms)
        #valid_data = datasets.ImageFolder(test_dir,transform = test_transforms)

        # 创建三个加载器,分别为训练,验证,测试,将训练集的batch大小设为64,即每次加载器向网络输送64张图片
        #shuffle 随机打乱,网络更容易学习不同的特征,更容易收敛
        print('load dataset......')
        trainloader = torch.utils.data.DataLoader(tain_smoke_data,batch_size = 64,shuffle = True)
        validloader = torch.utils.data.DataLoader(test_smoke_data,batch_size = 64)

        return trainloader,validloader

通过以上步骤,我们就获取了训练集和验证集的数据加载器,然后训练的时候使用方法如下

train_loader,test_loader=getClassfierDataset(train_dir, test_dir, self.input_shape)

以上就是自定义pytroch数据加载器的具体实现

源码参考

Pytorch/datasets/classifier/commonDataset.py

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch允许您创建自定义数据集以便于加载和处理您自己的数据。以下是一个简单的示例来创建自定义数据集: 首先,您需要导入必要的库: ```python import torch from torch.utils.data import Dataset ``` 然后,创建一个继承自`Dataset`类的自定义数据集类,并实现以下方法: - `__init__`:初始化数据集,例如加载数据或设置转换。 - `__len__`:返回数据集的大小。 - `__getitem__`:根据给定的索引返回一个样本。 下面是一个示例,假设您有一组图像数据和相应的标签: ```python class CustomDataset(Dataset): def __init__(self, data, labels): self.data = data self.labels = labels def __len__(self): return len(self.data) def __getitem__(self, index): sample = self.data[index] label = self.labels[index] # 在这里进行必要的数据转换 return sample, label ``` 在上面的示例中,`data`是图像数据的列表,`labels`是相应的标签的列表。然后,您可以在`__getitem__`方法中执行必要的数据转换,例如将图像转换为张量或应用任何其他预处理步骤。 要使用自定义数据集,您可以创建一个实例并将其传递给`DataLoader`类: ```python # 假设您有图像数据和标签 data = [...] # 图像数据列表 labels = [...] # 标签列表 # 创建自定义数据集实例 dataset = CustomDataset(data, labels) # 创建数据加载 dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True) ``` 现在,您可以使用`dataloader`来迭代加载批量的数据,并在训练模型时使用它们。 这只是一个简单的示例,您可以根据您的需求进行更多的自定义和扩展。希望这可以帮助到您!
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值