工作原理
-
Dataset:
Dataset
是自己定义的数据集类,它需要实现__len__
和__getitem__
方法。在__getitem__
方法中,你通常返回数据项(例如numpy
数组、列表、字典等)。 -
DataLoader:
DataLoader
从Dataset
中取出数据,并自动处理批处理(batching)、打乱(shuffling)、多进程加载(multi-process loading)等操作。 -
Collate 函数:
DataLoader
使用 collate 函数来将单个数据项合并成一个批次。默认的 collate 函数会将numpy
数组、Python 列表和其他支持的类型自动转换为 PyTorch 张量。
示例
以下是一个完整的示例,展示了如何使用 DataLoader
从 Dataset
中加载数据并自动转换为张量。
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
def __init__(self):
# 创建一些示例数据
self.data = np.random.randn(100, 3) # 100个样本,每个样本3个特征
self.labels = np.random.randint(0, 2, size=(100,)) # 100个样本,每个样本一个标签
def __len__(self):
# 返回数据集的大小
return len(self.data)
def __getitem__(self, idx):
# 返回数据项(数据和标签)
sample = self.data[idx]
label = self.labels[idx]
return sample, label
# 创建数据集实例
dataset = MyDataset()
# 创建DataLoader实例
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
# 遍历DataLoader
for batch_data, batch_labels in dataloader:
print(f"数据类型: {type(batch_data)}, 形状: {batch_data.shape}")
print(f"标签类型: {type(batch_labels)}, 形状: {batch_labels.shape}")
break # 只展示一个批次的数据
输出
数据类型: <class 'torch.Tensor'>, 形状: torch.Size([10, 3])
标签类型: <class 'torch.Tensor'>, 形状: torch.Size([10])
详细说明
-
自定义数据集:
MyDataset
继承自torch.utils.data.Dataset
,实现了__len__
和__getitem__
方法。在__getitem__
方法中,我们返回了数据项(numpy
数组)和标签。 -
DataLoader: 创建了一个
DataLoader
实例,设置batch_size=10
和shuffle=True
。 -
批处理: 在遍历
DataLoader
时,每个批次的数据和标签会被自动转换为 PyTorch 张量。
自定义 collate 函数
如果需要自定义数据项的合并方式,可以提供自己的 collate 函数:
def my_collate_fn(batch):
data, labels = zip(*batch)
data = torch.tensor(data)
labels = torch.tensor(labels)
return data, labels
dataloader = DataLoader(dataset, batch_size=10, shuffle=True, collate_fn=my_collate_fn)
通过这种方式,可以灵活地控制数据加载和转换过程。如果 DataLoader
默认的行为无法满足需求,提供自定义 collate 函数是一个不错的解决方案。