pytorch torchvision.datasets

torchvision 库是服务于pytorch深度学习框架的,用来生成图片,视频数据集,和一些流行的模型类和预训练模型. 

torchvision.datasets

所有数据集都是 torch.utils.data.dataset 的子类,也就是说,它们都实现了 __getitem__ 和 __len__ 方法。因此,它们都可以传递给 torch.utils.data.dataloader,后者可以使用 torch.multiprocessing workers 并行加载多个样本。例如:

imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=args.nThreads)

可获得的数据集如下:

目录

torchvision.datasets

MNIST

Fashion-MNIST

KMNIST

EMNIST

FakeData

COCO

Captions

Detection

LSUN

​ImageFolder

DatasetFolder

Imagenet-12

CIFAR

STL10

SVHN

PhotoTour

SBU

Flickr

VOC

Cityscapes

所有数据集都有几乎相似的API,都有两个共同的参数:transform 和 target_transform 分别对 input 和 target 进行转换 。

MNIST

CLASS torchvision.datasets.MNIST(roottrain=Truetransform=Nonetarget_transform=Nonedownload=False)

0-9手写数字 数据集。

Parameters:
  • root (string) –  存在 mnist/processed/training.pt 和 mnist/processed/test.pt 的数据集根目录。
  • train (booloptional) – 如果为True,从 training.pt 创建数据,否则从 test.pt 创建数据。
  • download (booloptional) – 如果为true,则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。
  • transform (callableoptional) –  一个函数/转换,它接收PIL图像并返回转换后的版本。 例如,transforms.RandomCrop
  • target_transform (callableoptional) – 接收目标并对其进行转换的函数/转换。

Fashion-MNIST

CLASS torchvision.datasets.FashionMNIST(roottrain=Truetransform=Nonetarget_transform=Nonedownload=False)

10类衣服标签的数据集。

每个 training 和 test 示例的标签如下:

LabelDescription
0T-shirt/top
1Trouser
2Pullover
3Dress
4Coat
5Sandal
6Shirt
7Sneaker
8Bag
9Ankle boot

KMNIST

CLASS torchvision.datasets.KMNIST(roottrain=Truetransform=Nonetarget_transform=Nonedownload=False)

手写日语片假名 数据集。

EMNIST

CLASS torchvision.datasets.EMNIST(rootsplit**kwargs)

MNIST数据库来自更大的数据集,称为NIST特殊数据库19,其包含数字,大写和小写手写字母。 完整NIST数据集的变体,称为扩展MNIST(EMNIST),它遵循用于创建MNIST数据集的相同转换范例。

Parameters:
  • root (string) – Root directory of dataset where EMNIST/processed/training.pt and EMNIST/processed/test.pt exist.
  • split (string) – 数据集有6种不同的分割:byclass,bymerge,balanced,letters,digits 和mnist。此参数指定要使用的参数。
  • train (booloptional) – If True, creates dataset from training.pt, otherwise from test.pt.
  • download (booloptional) – If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
  • transform (callableoptional) – A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop
  • target_transform (callableoptional) – A function/transform that takes in the target and transforms it.

FakeData

CLASS torchvision.datasets.FakeData(size=1000image_size=(3224224)num_classes=10transform=Nonetarget_transform=Nonerandom_offset=0)

假数据集,返回随机生成的图像并将其作为PIL图像返回。

Parameters:
  • size (intoptional) – 数据集大小。 Default: 1000 images
  • image_size (tupleoptional) – 返回图片的大小。 Default: (3, 224, 224)
  • num_classes (intoptional) – 数据集中类别数。 Default: 10
  • transform (callableoptional
  • target_transform (callableoptional)
  • random_offset (int) – 偏移用于生成每个图像的基于索引的随机种子。 Default: 0

COCO

需要安装Coco API

COCO数据集的使用:https://www.cnblogs.com/q735613050/p/8969452.html

Captions

CLASS torchvision.datasets.CocoCaptions(rootannFiletransform=Nonetarget_transform=None)

MS Coco Captions 数据集。

Parameters:
  • root (string
  • annFile (string) – json注释文件的路径。
  • transform (callableoptional)
  • target_transform (callableoptional

例子:

import torchvision.datasets as dset
import torchvision.transforms as transforms
cap = dset.CocoCaptions(root = 'dir where images are',
                        annFile = 'json annotation file',
                        transform=transforms.ToTensor())

print('Number of samples: ', len(cap))
img, target = cap[3] # load 4th sample

print("Image Size: ", img.size())
print(target)

# output:
Number of samples: 82783
Image Size: (3L, 427L, 640L)
[u'A plane emitting smoke stream flying over a mountain.',
u'A plane darts across a bright blue sky behind a mountain covered in snow',
u'A plane leaves a contrail above the snowy mountain top.',
u'A mountain that has a plane flying overheard in the distance.',
u'A mountain view with a plume of smoke in the background']

     __getitem__(index)

Parameters:index (int) – Index
Returns:Tuple (image, target) target 是 image 的标题列表。
Return type:tuple

Detection

CLASS torchvision.datasets.CocoDetection(rootannFiletransform=Nonetarget_transform=None)

MS Coco Detaction 数据集。

     __getitem__(index)

Parameters:index (int) – Index
Returns:Tuple (image, target). target是coco.loadAnns返回的对象。
Return type:tuple

LSUN

CLASS torchvision.datasets.LSUN(rootclasses='train'transform=Nonetarget_transform=None)

Parameters:
  • root (string)
  • classes (string or list) – One of {‘train’, ‘val’, ‘test’} or a list of categories to load. e,g. [‘bedroom_train’, ‘church_train’].
  • transform (callableoptional
  • target_transform (callableoptional


ImageFolder

CLASS torchvision.datasets.ImageFolder(roottransform=Nonetarget_transform=Noneloader=<function default_loader>)

通用数据加载器,其中图像以这种方式排列:

root/dog/xxx.png

root/dog/xxy.png

root/dog/xxz.png

root/cat/123.png

root/cat/nsdf3.png

root/cat/asd932_.png

Parameters:
  • root (string
  • transform (callableoptional)
  • target_transform (callableoptional
  • loader – 给定路径的图像加载功能。

DatasetFolder

CLASS torchvision.datasets.DatasetFolder(rootloaderextensionstransform=Nonetarget_transform=None)

通用数据加载器,其中样本以这种方式排列:

root/calss_x/xxx.ext

root/calss_x/xxy.ext

root/calss_x/xxz.ext

root/calss_y/123.ext

root/calss_y/nsdf3.ext

root/calss_y/asd32_.ext

Parameters:
  • root (string) – 根目录路径。
  • loader (callable) – 一个在给定路径的情况下加载样本的函数。
  • extensions (list[string]) – 允许的扩展名列表。
  • transform (callableoptional) – 一个函数/转换,它接收一个样本并返回一个转换后的版本。 例如,transforms.RandomCrop用于图像。
  • target_transform – 接收目标并对其进行转换的函数/转换。

Imagenet-12

这应该只使用 ImageFolder 数据集实现。ImageNet大规模视觉识别挑战(ILSVRC)数据集有1000个类别和120万个图像。 图像不需要在任何数据库中进行预处理或打包,但需要将验证图像移动到适当的子文件夹中。

CIFAR

CLASS torchvision.datasets.CIFAR10(roottrain=Truetransform=Nonetarget_transform=Nonedownload=False)

CIFAR10 数据集由10个类中的60000个32x32彩色图像组成,每个类有6000个图像。 有50000个训练图像和10000个测试图像。数据集分为五个训练 batch 和一个测试 batch ,每个 batch 有10000个图像。 测试 batch 包含来自每个类别的1000个随机选择的图像。训练 batch 以随机顺序包含剩余图像,但是一些训练 batch 可能包含来自一个类别的更多图像而不是另一个类别。 training batch包含来自每个 class 的5000个图像。

Parameters:
  • root (string) – 数据集的根目录,其中目录cifar-10-batches-py存在,如果 download 设置为True 将保存数据。
  • train (booloptional
  • transform (callableoptional
  • target_transform (callableoptional
  • download (booloptional)

CLASS torchvision.datasets.CIFAR100(roottrain=Truetransform=Nonetarget_transform=Nonedownload=False)

CIFAR100 数据集与CIFAR-10类似,不同之处在于它有100个类,每个类包含600个图像。 每个类有500个训练图像和100个测试图像。 CIFAR-100中的100个类被分为20个超类。 每个图像都带有一个“精细”标签(它所属的类)和一个“粗”标签(它所属的超类)。

STL10

CLASS torchvision.datasets.STL10(rootsplit='train'transform=Nonetarget_transform=Nonedownload=False)

STL-10 数据集是用于开发无监督特征学习,深度学习,自学习学习算法的图像识别数据集。它的灵感来自CIFAR-10数据集,但有一些修改。特别地,每个类具有比CIFAR-10更少的标记训练示例,但是提供了非常大的一组未标记示例以在监督训练之前学习图像模型。 主要的挑战是利用未标记的数据(来自与标记数据相似但不同的分布)来构建有用的先验数据。 期望该数据集的更高分辨率(96x96)将使其成为开发更具可扩展性的无监督学习方法的具有挑战性的基准。

Parameters:
  • root (string) –  stl10_binary 根目录存放数据集。
  • split (string) – One of {‘train’, ‘test’, ‘unlabeled’, ‘train+unlabeled’}. 选择对应数据集。
  • transform (callableoptional
  • target_transform (callableoptional)
  • download (booloptional

SVHN

CLASS torchvision.datasets.SVHN(rootsplit='train'transform=Nonetarget_transform=Nonedownload=False)

SVHN数据集(the Street View House Numbers (SVHN) 街景号码数据集)注意:SVHN数据集将标签10分配给数字0。但是,在此数据集中,我们将标签0分配给数字0以与PyTorch损失函数兼容,这些函数期望类标签在[0,C-1]范围内。

Parameters:
  • root (string)
  • split (string) – One of {‘train’, ‘test’, ‘extra’}. 选择对应数据集。 ‘extra’ 是扩展的训练集。
  • transform (callableoptional)
  • target_transform (callableoptional
  • download (booloptional)

PhotoTour

CLASS torchvision.datasets.PhotoTour(rootnametrain=Truetransform=Nonedownload=False)

数据集由1024 x 1024位图(.bmp)图像组成,每个图像包含16 x 16阵列的图像块。每个 patch 采样为64 x 64灰度,具有规范的比例和方向。关联的元数据文件 info.txt 包含匹配信息。 info.txt 的每一行对应一个单独的 patch, patch 在每个位图图像中从左到右,从上到下排序。 info.txt每行的第一个数字是从中采样该 patch 的3D点ID  - 具有相同3D点ID的 patch 从相同的3D点投射到不同的图像中。 info.txt中的第二个数字对应于采样 patch 的图像,目前尚未使用。

     __getitem__(index)

Parameters:index (int) – Index
Returns:(data1, data2, matches)
Return type:tuple

SBU

CLASS torchvision.datasets.SBU(roottransform=Nonetarget_transform=Nonedownload=True)

Im2Text:使用100万张标题照片描述图像。 

Flickr

CLASS torchvision.datasets.Flickr8k(rootann_filetransform=Nonetarget_transform=None)

Parameters:
  • root (string)
  • ann_file (string) – 注释文件的路径。
  • transform (callableoptional) – 一个函数/转换,它接收PIL图像并返回转换后的版本。 E.g, transforms.ToTensor
  • target_transform (callableoptional

     __getitem__(index)

Parameters:index (int) – Index
Returns:Tuple (image, target). target is a list of captions(字幕) for the image.
Return type:tuple

CLASS torchvision.datasets.Flickr30k(rootann_filetransform=Nonetarget_transform=None)

VOC

CLASS torchvision.datasets.VOCSegmentation(rootyear='2012'image_set='train'download=Falsetransform=Nonetarget_transform=None)

Parameters:
  • root (string)
  • year (stringoptional) – 数据集年份,支持从 2007 到 2012。
  • image_set (stringoptional) – 选择要使用的image_set,train,trainval 或 val
  • download (booloptional)
  • transform (callableoptional)
  • target_transform (callableoptional

     __getitem__(index)

Parameters:index (int) – Index
Returns:(image, target) 其中 target 是 image segmentation(分割).
Return type:

tuple

CLASS torchvision.datasets.VOCDetection(rootyear='2012'image_set='train'download=Falsetransform=Nonetarget_transform=None)

     __getitem__(index)

Parameters:index (int) – Index
Returns:(image, target) 其中 target is a dictionary of the XML tree(是XML树的字典).
Return type:tuple

Cityscapes

需要下载 cityscape。

CLASS torchvision.datasets.Cityscapes(rootsplit='train'mode='fine'target_type='instance'transform=Nonetarget_transform=None)

Parameters:
  • root (string) – 目录 leftImg8bit 和 gtFine 或 gtCoarse 所在的数据集的根目录。 
  • split (stringoptional) – 如果 mode =“gtFine”,则图像分为 use,use, traintest or val,否则为 traintrain_extra or val
  • mode (stringoptional) – 要使用的模式,gtFine或gtCoarse
  • target_type (string or listoptional) – 要使用的目标类型(instancesemanticpolygon or color)实例,语义,多边形或颜色。 也可以是一个列表,用于输出具有所有指定目标类型的元组。
  • transform (callableoptional
  • target_transform (callableoptional

例子

获取语义分割目标

dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                     target_type='semantic')

img, smnt = dataset[0]

获得多个目标

dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                     target_type=['instance', 'color', 'polygon'])

img, (inst, col, poly) = dataset[0]

在“coarse”集上验证

dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
                     target_type='semantic')

img, smnt = dataset[0]

     __getitem__(index)

Parameters:index (int) – Index
Returns:(image, target) 如果target_type是具有多个项目的列表,target是所有目标类型的元组。否则,如果target_type =“polygon”,则target是json对象,否则是图像分割。
Return type:tuple
  • 13
    点赞
  • 58
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值