1.DataLoder和Dataset
DataLoader是PyTorch中用于高效加载数据的重要组件,它主要与用户自定义的Dataset类配合使用,以支持模型训练或验证过程中的数据迭代。以下是使用DataLoader的基本指南,包括其输入、输出及使用要求:
输入要求
dataset (torch.utils.data.Dataset实例): 这是最基本的输入,代表了数据集。你需要先定义一个继承自torch.utils.data.Dataset的类,并实现__len__和__getitem__方法。__len__返回数据集的大小,__getitem__根据索引返回数据样本及其标签(如果有)。
batch_size (int, 可选): 每个批次的数据样本数量。这是训练期间模型一次处理的数据量。
shuffle (bool, 可选): 是否在每个epoch开始时随机打乱数据集。通常在训练时设为True,验证或测试时设为False。
num_workers (int, 可选): 使用多少个子进程来加速数据加载。适合数据集较大,I/O成为瓶颈的情况。
collate_fn (callable, 可选): 用于整理一个批次内的数据样本,使其能被模型接受。例如,列表的列表转换成张量。
sampler 和 batch_sampler (可选): 自定义采样策略,控制数据的读取顺序。
输出
迭代器 (DataLoader实例本身是一个迭代器): 在迭代过程中,每次调用迭代器(如在for data in dataloader:循环中),它会返回一个批次的数据,通常是(data_batch, label_batch)的形式,其中data_batch和label_batch是张量,分别包含了该批次所有样本的数据和标签。
使用示例from torch.utils.data import DataLoader from torchvision import datasets, transforms # 定义数据预处理 transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]) # 加载数据集 dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform) # 创建DataLoader dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4) # 使用DataLoader进行训练或验证 for epoch in range(num_epochs): for inputs, labels in dataloader: # 训练或验证步骤 # inputs 和 labels 是当前批次的数据和标签 pass from torch.utils.data import DataLoader from torchvision import datasets, transforms # 定义数据预处理 transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]) # 加载数据集 dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform) # 创建DataLoader dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4) # 使用DataLoader进行训练或验证 for epoch in range(num_epochs): for inputs, labels in dataloader: # 训练或验证步骤 # inputs 和 labels 是当前批次的数据和标签 pass
(这段代码是在准备MNIST数据集的预处理过程,具体来说:
定义数据预处理:transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
这行代码定义了一个预处理流水线,包含两个操作:
transforms.ToTensor(): 这个操作将数据从PIL Image或numpy数组格式转换为PyTorch张量。通常,图像数据的通道顺序是RGB,但MNIST是灰度图像,所以转换后张量的形状将是[C, H, W],其中C=1,H和W是高度和宽度。
transforms.Normalize(mean=(0.5,), std=(0.5,)): 这个操作对数据进行归一化,使得每个通道的像素值的均值为0.5,标准差为0.5。对于MNIST数据集,由于它是灰度图像,只有一个通道,所以均值和标准差都只有一项。这有助于模型更快地收敛。
加载数据集:dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
这行代码加载了MNIST数据集:
root='./data': 指定数据集存储的位置。如果文件夹不存在,PyTorch会尝试下载数据集到该位置。
train=True: 表示加载训练集。如果你想加载测试集,可以设置为False。
download=True: 如果数据集不在root指定的路径下,自动下载数据集。
transform=transform: 将之前定义的预处理流水线应用到加载的数据集上。这意味着每个样本在被加载时都会经过这两个预处理步骤:转为张量并进行归一化。
因此,这段代码的作用是下载并预处理MNIST数据集,将其转换为张量,并对每个像素值进行归一化,以便后续模型的训练。)
注意事项
确保num_workers设置合理,过多的工作进程可能会因为资源竞争而降低效率。
对于图像数据,使用transforms.Compose组合多个变换是常见的做法。
如果数据集很大,考虑使用IterableDataset代替MapDataset,特别是在流式数据或无限数据集的情况下。
collate_fn可以自定义以处理不规则数据结构,例如不同长度的序列数据
断点