当我们用 PyTorch 来训练神经网络时,经常需要用到 Dataset
和 DataLoader
这两个类。它们都是 PyTorch 中的数据处理工具,用于读取和处理大量的数据,并将其转换为可供神经网络使用的格式。
Dataset
Dataset
类是一个抽象类,定义了读取数据集的接口方法。我们可以通过继承 Dataset
类,并实现其中的 __len__()
和 __getitem__()
方法来创建自定义的数据集。其中:
__len__()
方法返回数据集中样本的数量;__getitem__()
方法可以根据给定的索引(从 0 开始)获取数据集中对应的一个样本。
下面是一个简单的示例,该示例演示了如何从 CSV 文件中读取数据,并将其转换为 PyTorch 的 Tensor 对象:
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, csv_file):
self.data = pd.read_csv(csv_file)
self.X = self.data.iloc[:, :-1].values
self.y = self.data.iloc[:, -1:].values
def __len__(self):
return len(self.data)
def __getitem__(self, index):
X = torch.tensor(self.X[index], dtype=torch.float32)
y = torch.tensor(self.y[index], dtype=torch.float32)
return X, y
在上述代码中,我们首先通过 Pandas 库将 CSV 文件读入到内存中;然后在 __init__()
方法中将数据集的 X 和 y 值分别存储到 self.X
和 self.y
中;最后在 __getitem__()
方法中,根据给定的索引获取对应的样本,并将其转换为 PyTorch 的 Tensor 对象。
DataLoader
DataLoader
类是一个可迭代的对象,可以用于批量加载数据集。它可以从 Dataset
对象(或其他类似于数组的对象)中获取数据,并在训练过程中生成随机批次的数据。其核心思想是使用多个 worker 进程并行地预取和处理数据,以加快数据读取速度,并在内存中组织成批次形式。
下面是一个简单的示例:
from torch.utils.data import DataLoader
# 创建 MyDataset 对象
my_dataset = MyDataset('data.csv')
# 创建 DataLoader 对象
batch_size = 32
dataloader = DataLoader(my_dataset, batch_size=batch_size, shuffle=True)
# 遍历 DataLoader 对象
for i, (X_batch, y_batch) in enumerate(dataloader):
# 在此处进行模型训练或测试
pass
在上述代码中,我们首先创建了一个 MyDataset
对象,并将其传递给了 DataLoader
构造函数中。我们还指定了 batch_size
参数,表示每个批次中包含的样本数量。设置 shuffle=True
表示每个 epoch 开始时,都会对数据集进行重新洗牌,以增加训练的随机性。
在实际训练过程中,我们可以通过遍历 DataLoader
对象来获取随机批次的数据,并将其用于模型训练或测试。在遍历 DataLoader
对象时,每个批次的数据都以元组 (X_batch, y_batch)
的形式返回,其中 X_batch
和 y_batch
分别表示一个批次中的特征和标签。
补充
下面介绍一下 DataLoader
是如何从 Dataset
对象中获取数据,并生成随机批次的:
- 在创建
DataLoader
对象时,需要传入一个Dataset
对象,表示要加载的数据集。 DataLoader
会依次遍历Dataset
中的每个样本,并将其添加到一个存储样本的列表中。此时,如果设置了num_workers
参数,则会启动对应数量的 worker 进程,每个进程都会异步地预取和处理数据,并将结果添加到列表中。- 当列表中的元素数量等于
batch_size
参数指定的值时,DataLoader
就会将这些元素封装成一个 batch,并返回给用户。同时,如果设置了shuffle
参数为 True,还会随机打乱列表中的元素顺序,以增加数据的随机性。 - 如果
Dataset
中的样本数量不足以填满一个 batch,那么DataLoader
会将剩余样本封装成一个小于batch_size
的 batch 并返回。另外,如果drop_last
参数设置为 True,则会舍弃最后一个小于batch_size
的 batch,以保证每个 batch 中样本数量一致。
通过这种方式,DataLoader
可以高效、方便地从数据集中加载数据,并将其封装成批次形式,供模型训练使用。