译者:BXuan694
所有的数据集都是torch.utils.data.Dataset
的子类, 即:它们实现了__getitem__
和__len__
方法。因此,它们都可以传递给torch.utils.data.DataLoader
,进而通过torch.multiprocessing
实现批数据的并行化加载。例如:
imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
batch_size=4,
shuffle=True,
num_workers=args.nThreads)
目前为止,收录的数据集包括:
数据集
- MNIST
- Fashion-MNIST
- EMNIST
- COCO
- LSUN
- ImageFolder
- DatasetFolder
- Imagenet-12
- CIFAR
- STL10
- SVHN
- PhotoTour
- SBU
- Flickr
- VOC
以上数据集的接口基本上很相近。它们至少包括两个公共的参数transform
和target_transform
,以便分别对输入和和目标做变换。