1、简述
PyTorch中对数据集的描述类都是从torch.utils.data.Dataset继承,然后交给torch.utils.data.DataLoader来处理。
在模型训练时,只操作 torch.utils.data.DataLoader 即可,这样就将数据集的代码和训练的代码分离,方便维护。
2、PyTorch预加载数据集
PyTorch针对图像、音频、文本都提供了常用的内置数据集。这些数据集继承自torch.utils.data.Dataset,并且实现了如下方法,可以传递给torch.utils.data.DataLoader
__getitem__
__len__
数据加载示例:
imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
batch_size=4,
shuffle=True,
num_workers