自定义ImageFolder和DataLoader

需求:想要给dataset的每一个图像加上路径信息。

train_dl = DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn)

DataLoader根据batch_size大小,每次得到一组长度为batch_size的index,根据index调用dataset.__getitem__()方法,将返回的结果存在列表中,最后将列表作为输入参数传递给collate_fn,collate_fn的返回值即为每次next(iter(train_dl))的返回值。

例如,如果batch_size=3,那么index可能为5,27,49,则将列表

batch=[dataset.__getitem__(5),dataset.__getitem__(27),dataset.__getitem__(49)]

输入到collate_fn(batch)中,随后,collate_fn(batch)的返回值即为next(iter(train_dl))的返回值。由此可知,为了使得dataset和train_dl有路径信息,需要重写__getitem__()方法。于是,我们创建一个名为ImageFolderWithID的类,继承自ImageFolder

class ImageFolderWithID(ImageFolder):
    def __init__(self,root,img_transforms):
        super(ImageFolderWithID, self).__init__(root,transform=img_transforms)
    def __getitem__(self, item):
        return super(ImageFolderWithID, self).__getitem__(item),self.samples[item][0]
    def __len__(self):
        return super(ImageFolderWithID, self).__len__()

imgset = ImageFolderWithID(root_path,data_transforms)

这样一来,当调用imgset.__getitem__(5)的时候,返回的元组的第一个元素是原来的既包含图像信息,又包含图像类别的子元组;第二个元素是该图片的路径,即

imgset.__getitem__(5)=( ( 图片矩阵,图片类别 ) , 图片路径 )

为了在训练的时候只把矩阵和类别提取出来,而忽略路径,需要重写DataLoader的collate_fn函数

def my_collate_fn(the_batch):
    origin_batch, path_list = [], []
    for item in the_batch:
        origin_batch.append(item[0])
        path_list.append(item[1])
    img_x, label_y = default_collate(origin_batch)
    return img_x, label_y, path_list

train_ds, test_ds = random_split(imgset, [0.8, 0.2])
train_dl = DataLoader(dataset=train_ds,
                      batch_size=batch_size, 
                      shuffle=True, 
                      collate_fn=my_collate_fn)

前面提到过,the_batch是

the_batch=[dataset.__getitem__(5),dataset.__getitem__(27),dataset.__getitem__(49)]

而每一个__getitem__的返回值是

imgset.__getitem__(5)=( ( 图片矩阵,图片类别 ) , 图片路径 )

因此my_collate_fn先将每一个__getitem__()的值提取出前面一个元组,将它们送给默认的collate函数:default_collate得到我们训练需要的img_x和label_y,至于path_list,在训练的时候解包用_字符占位即可。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
datasets 和 dataloaderPyTorch 中用于处理和加载数据的两个重要模块。 在 PyTorch 中,datasets 用于存储和处理数据集,例如图像、文本等。PyTorch 提供了许多内置的 datasets,如 torchvision 中的 ImageFolder 和 MNIST,也可以自定义 datasets。 下面是使用 torchvision.datasets.ImageFolder 加载图像数据集的示例代码: ```python import torchvision.datasets as datasets # 定义数据集路径 data_dir = 'path/to/dataset' # 创建 ImageFolder 数据集 dataset = datasets.ImageFolder(data_dir) # 获取数据集的长度 dataset_size = len(dataset) # 获取类别标签 class_labels = dataset.classes # 可以通过索引访问数据集中的样本 sample, label = dataset[0] # 可以通过迭代器遍历整个数据集 for sample, label in dataset: # 在这里对样本进行处理/转换 pass ``` 接下来,我们可以使用 dataloader 对数据集进行批量加载和预处理。dataloader 可以方便地将数据集划分为小批量样本,进行数据增强或标准化等操作。 下面是使用 torch.utils.data.DataLoader 对数据集进行批量加载的示例代码: ```python import torch.utils.data as data # 定义批量大小和多线程加载数据的工作进程数 batch_size = 32 num_workers = 4 # 创建 dataloader dataloader = data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True) # 可以通过迭代器遍历整个数据集的小批量样本 for batch_samples, batch_labels in dataloader: # 在这里对小批量样本进行处理/转换 pass ``` 在上面的示例中,我们创建了一个 dataloader,并指定了批量大小和加载数据的工作进程数。`shuffle=True` 表示每个 epoch 都会对数据进行随机打乱,以增加数据的多样性。 通过使用 datasets 和 dataloader,我们可以方便地加载和处理各种类型的数据集,并应用各种预处理操作。这些模块的使用可以大大简化数据加载和处理的过程,提高代码的可读性和可维护性。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值