pytorch中Dataset、Dataloader、Sampler、collate_fn相互关系和使用说明

提示:本文文字部分80%以上由大模型生成,人工做了校正。

参考: https://blog.csdn.net/Chinesischguy/article/details/103198921

参考: https://zhuanlan.zhihu.com/p/76893455

参考:https://blog.csdn.net/lilai619/article/details/118784730

参考:https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader

        本博客旨在介绍PyTorch深度学习框架中Dataset、Dataloader、Sampler、collate_fn组件之间相互关系,以及如何自定义各组件。这些组件是深度学习项目中不可或缺的组成部分,对于理解和使用PyTorch框架进行深度学习任务至关重要。

        在PyTorch深度学习框架中,Dataset、Dataloader、Sampler和collate_fn是数据加载和处理过程中非常重要的组成部分。它们之间的调用关系如下:

  1. Dataset:定义了数据集的接口,用于读取和处理数据。通常情况下,Dataset是从文件或数据库中读取数据的集合,它可以对数据进行预处理、增强等操作,并返回一个可迭代的对象,用于后续的数据加载过程。

  2. Dataloader:实现了数据集的批量加载功能。Dataloader可以根据Dataset返回的可迭代对象,将数据分成多个batch,并按照指定的采样方式(如随机采样、分层采样等)进行采样。同时,Dataloader还可以自动调整batch size、设置数据加载器状态等。

  3. Sampler:定义了数据集中每个batch所包含的数据的位置索引。通常情况下,Sampler是在数据加载之前设置的一个对象,它可以根据用户指定的要求(如按照类别、标签等)对数据集进行采样,并返回每个batch所包含的数据的位置索引。

  4. collate_fn:用于将一个batch中的数据进行拼接和整理。通常情况下,collate_fn是在Dataloader创建时设置的一个函数,它可以根据Dataset返回的可迭代对象和Sampler返回的位置索引,将不同长度的输入数据转换为统一的形状,并返回一个新的tensor作为batch的数据。

        综上所述,Dataset、Dataloader、Sampler和collate_fn之间是相互协作的,它们共同完成了数据加载和处理的过程。具体来说,Dataset提供了数据集的接口和一些基本的操作;Dataloader实现了数据的批量加载和一些高级的功能;Sampler根据用户指定的要求对数据集进行采样;collate_fn负责将不同长度的输入数据转换为统一的形状。本文将讨论这四个组件的使用方法,并提供一些自定义各组件的技术实践经验。我们将从以下几个方面来探讨:

        1. Dataset的使用方法和自定义技巧;

        2. Sampler的使用方法和自定义技巧;

        3. collate_fn的使用方法和自定义技巧。

DataLoader, Sampler, Dataset三者的关系

        1. Sampler提供indicies

        2. Dataset根据indicies提供data,使用__getitem__方法

        3. DataLoader将上面两个组合起来,提供最终的batch训练数据,其中collate_fn可以对batch中的数据做额外的处理

自定义Dataset

        在PyTorch中,可以通过继承torch.utils.data.Dataset类来自定义数据集(Dataset)类。自定义的数据集类可以包含自己的数据加载和预处理方法,以及一些额外的元数据。

import torch
from torch.utils.data import Dataset, DataLoader, Sampler, BatchSampler
import torchvision
from torchvision.io import read_image
import random
import numpy as np
from matplotlib import pyplot as plt
from collections import Counter


class MyDataset(Dataset):
    """
        加载磁盘上的图像文件,并进行transform变换,返回变换后的图片和与之对应的标签编号
    """

    def __init__(self, filenames, labels, transforms_pipeline=None):
        super().__init__()
        # 所有图像的路径列表
        self.filenames = filenames
        # 所有图片对应的label标签编号,从0开始
        self.labels = labels
        # 图像预处理
        self.transforms_pipeline = transforms_pipeline

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        filepath = self.filenames[idx]
        img = read_image(filepath, mode=torchvision.io.ImageReadMode.RGB)
        if self.transforms_pipeline:
            img = self.transforms_pipeline(img)
        return img, self.labels[idx]

        以上代码自定义了一个Dataset类用于加载训练数据,训练数据中cat和dog目录下分别存储的是猫和狗的图片。

         使用以下代码片段测试自定义的Dataset数据加载情况:

transforms_pipeline = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),
    ]
)

# 图像存放位置,其中包含两个目录,cat和dog,cat下存放猫的图片,dog下存放狗的图片
data_path = "XXX"
image_folder = torchvision.datasets.ImageFolder(data_path)
# image_folder.samples 中存放的是图像数据的文件路径和类别索引编号(从0开始编号)
random.shuffle(image_folder.samples)
# image_folder.classes image_folder.samples中存放的类别索引编号相对应
classes = image_folder.classes
# 用于存放图像路径列表
filenames = []
# 用于存放图像对应的类别
labels = []
for image_path, label in image_folder.samples:
    # print(image_path, label)
    filenames.append(image_path)
    labels.append(label)
print(filenames, labels)

# 使用自定义Dataset类加载磁盘上的图上数据
my_dataset = MyDataset(filenames, labels, transforms_pipeline)
img, label = my_dataset[10]
print(img.shape, label)

自定义Sampler

        在PyTorch中,可以通过继承torch.utils.data.Sampler类来自定义采样器(Sampler)类。自定义的采样器类可以控制数据集中每个样本的采样方式,例如随机采样、分块采样等。

class MySampler(Sampler):
    """
        自定义Sampler,在__iter__函数中定义indices的生成方式,也叫生成顺序
    """

    def __init__(self, labels):
        self.labels = np.array(labels)
        self.image_ids = []

    def __iter__(self):
        """
            在每个batch中包含的每个类别的数量相等
        :return:
        """
        indices = []
        counter = Counter(self.labels)
        # 统计数据量最多的类别
        most_common = counter.most_common(1)[0][1]
        # 统计每张图片在filenames这个列表中对应的索引编号
        for c in range(len(counter)):
            indices.append(np.where(self.labels == c)[0].tolist())

        # 所有类别通过复制的方式与最多的类别对齐
        for indice in indices:
            if len(indice) < most_common:
                indice.extend(random.choices(indice, k=most_common - len(indice)))
            random.shuffle(indice)

        # 依次从所有类别中分别取一张图片组成batch
        for ids in zip(*indices):
            self.image_ids.extend(list(ids))

        return iter(self.image_ids)

    def __len__(self):
        return len(self.image_ids)

        以上自定义Sampler控制在返回训练样本编号的逻辑,使得每个batch中的各类别数据量相等,Sampler返回训练样本的编号,然后使用Dataset的__getitem__方法取出对应的样本。

        使用以下代码片段测试自定义的Sampler的数据采样情况:

my_sampler = MySampler([1, 2, 3, 4, 1, 2, 3, 4, 0, 0, 0])
sample_labels = []
for x in my_sampler:
    print(x)
    sample_labels.append(my_sampler.labels[x])
print(sample_labels)
print(len(my_sampler))

自定义collate_fn函数

        在PyTorch中,自定义collate_fn函数可以用于对数据集中的数据进行整合和处理。当使用自定义采样器(Sampler)加载数据时,collate_fn函数会被自动调用来整合每个batch的数据。

def collate_fn(batch_data):
    """
        对batch中的图像使用mixup,并返回mixup之后的结果
    :param batch_data:
    :return:
    """

    def mixup_data(x, y, alpha=1.0, use_cuda=False):
        if alpha > 0:
            lam = np.random.beta(alpha, alpha)
        else:
            lam = 1

        batch_size = x.size()[0]

        if use_cuda:
            index = torch.randperm(batch_size).cuda()
        else:
            index = torch.randperm(batch_size)

        mixed_x = lam * x + (1 - lam) * x[index, :]
        y_a, y_b = y, y[index]

        return mixed_x, y_a, y_b, lam

    batch_img = []
    batch_label = []
    for img, label in batch_data:
        batch_img.append(img)
        batch_label.append(label)

    batch_img = torch.stack(batch_img, dim=0)
    batch_label = torch.tensor(batch_label)
    # print(batch_img.shape, batch_label.shape)

    batch_img, batch_label_a, batch_label_b, batch_lam = mixup_data(batch_img, batch_label)
    return batch_img, batch_label_a, batch_label_b, batch_lam

        在以上自定义collate_fn函数中,我们在每个batch批量样本之间使用mixup数据增强,并返回mixup之后的增强数据以及对应的标签和参数。

自定义Dataset、Sampler、collate_fn,以及使用Dataloader的完整代码

# coding:utf-8

import torch
from torch.utils.data import Dataset, DataLoader, Sampler, BatchSampler
import torchvision
from torchvision.io import read_image
import random
import numpy as np
from matplotlib import pyplot as plt
from collections import Counter


class MyDataset(Dataset):
    """
        加载磁盘上的图像文件,并进行transform变换,返回变换后的图片和与之对应的标签编号
    """

    def __init__(self, filenames, labels, transforms_pipeline=None):
        super().__init__()
        # 所有图像的路径列表
        self.filenames = filenames
        # 所有图片对应的label标签编号,从0开始
        self.labels = labels
        # 图像预处理
        self.transforms_pipeline = transforms_pipeline

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        filepath = self.filenames[idx]
        img = read_image(filepath, mode=torchvision.io.ImageReadMode.RGB)
        if self.transforms_pipeline:
            img = self.transforms_pipeline(img)
        return img, self.labels[idx]


def collate_fn(batch_data):
    """
        对batch中的图像使用mixup,并返回mixup之后的结果
    :param batch_data:
    :return:
    """

    def mixup_data(x, y, alpha=1.0, use_cuda=False):
        if alpha > 0:
            lam = np.random.beta(alpha, alpha)
        else:
            lam = 1

        batch_size = x.size()[0]

        if use_cuda:
            index = torch.randperm(batch_size).cuda()
        else:
            index = torch.randperm(batch_size)

        mixed_x = lam * x + (1 - lam) * x[index, :]
        y_a, y_b = y, y[index]

        return mixed_x, y_a, y_b, lam

    batch_img = []
    batch_label = []
    for img, label in batch_data:
        batch_img.append(img)
        batch_label.append(label)

    batch_img = torch.stack(batch_img, dim=0)
    batch_label = torch.tensor(batch_label)
    # print(batch_img.shape, batch_label.shape)

    batch_img, batch_label_a, batch_label_b, batch_lam = mixup_data(batch_img, batch_label)
    return batch_img, batch_label_a, batch_label_b, batch_lam


class MySampler(Sampler):
    """
        自定义Sampler,在__iter__函数中定义indices的生成方式,也叫生成顺序
    """

    def __init__(self, labels):
        self.labels = np.array(labels)
        self.image_ids = []

    def __iter__(self):
        """
            在每个batch中包含的每个类别的数量相等
        :return:
        """
        indices = []
        counter = Counter(self.labels)
        # 统计数据量最多的类别
        most_common = counter.most_common(1)[0][1]
        # 统计每张图片在filenames这个列表中对应的索引编号
        for c in range(len(counter)):
            indices.append(np.where(self.labels == c)[0].tolist())

        # 所有类别通过复制的方式与最多的类别对齐
        for indice in indices:
            if len(indice) < most_common:
                indice.extend(random.choices(indice, k=most_common - len(indice)))
            random.shuffle(indice)

        # 依次从所有类别中分别取一张图片组成batch
        for ids in zip(*indices):
            self.image_ids.extend(list(ids))

        return iter(self.image_ids)

    def __len__(self):
        return len(self.image_ids)


## 测试自定义Sampler
# my_sampler = MySampler([1, 2, 3, 4, 1, 2, 3, 4, 0, 0, 0])
# sample_labels = []
# for x in my_sampler:
#     print(x)
#     sample_labels.append(my_sampler.labels[x])
# print(sample_labels)
# print(len(my_sampler))


transforms_pipeline = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),
    ]
)

# 图像存放位置,其中包含两个目录,cat和dog,cat下存放猫的图片,dog下存放狗的图片
data_path = r"C:\WorkDir\PythonWorkspace\MusicRecognition\mixup-cifar10-main\data\cat_and_dog"
image_folder = torchvision.datasets.ImageFolder(data_path)
# image_folder.samples 中存放的是图像数据的文件路径和类别索引编号(从0开始编号)
random.shuffle(image_folder.samples)
# image_folder.classes image_folder.samples中存放的类别索引编号相对应
classes = image_folder.classes
# 用于存放图像路径列表
filenames = []
# 用于存放图像对应的类别
labels = []
for image_path, label in image_folder.samples:
    # print(image_path, label)
    filenames.append(image_path)
    labels.append(label)
print(filenames, labels)

# 使用自定义Dataset类加载磁盘上的图上数据
my_dataset = MyDataset(filenames, labels, transforms_pipeline)
# img, label = my_dataset[10]
# print(img.shape, label)

# 使用自定义collate_fn函数,在每个batch中进行mixup图片增强,并返回增强后的图片数据、标签、以及mixup系数
dataloader = DataLoader(
    my_dataset,
    batch_size=8,  # batch_size要能整除类别数
    shuffle=False,  # 使用sampler时,shuffle参数要设置为False
    sampler=MySampler(labels),  # 自定义Sampler,返回的batch中每种类别的数量相等
    batch_sampler=None,
    collate_fn=collate_fn  # 自定义collate_fn,其中执行mixup数据增强
)

for batch_img, batch_label_a, batch_label_b, batch_lam in dataloader:
    print(batch_img.shape, batch_label_a.shape, batch_label_b.shape, batch_lam)
    # batch中包含每个类别的数量相等,猫和狗都是4张
    # {0: 4, 1: 4}
    print(Counter(batch_label_a.detach().cpu().numpy().tolist()))
    break

for idx, img in enumerate(batch_img):
    plt.imshow(img.permute(1, 2, 0).int().clamp(min=0, max=255).detach().cpu().numpy())
    plt.show()

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值