参考自官网:torchvision.datasets
总介绍
torchvision.datasets中包含了以下数据集
- MNIST
- COCO(用于图像标注和目标检测)(Captioning and Detection)
- LSUN Classification
- ImageFolder
- Imagenet-12
- CIFAR10 and CIFAR100
- STL10
详细介绍(以mnist
手写数字集为例)
- 数据集介绍
60000个训练数据,10000个测试数据,每张图片大小28*28。
单通道的黑白色图片,即(batch_size, channels, Height, Width) =(batch_size, 1, 28, 28) - 参数列表
MNIST(root, train=True, transform=None, target_transform=None, download=False)
参数说明:- root : processed/training.pt 和 processed/test.pt 的主目录
- train : True = 训练集, False = 测试集
- target_transform:一个函数,原始图片作为输入,返回一个转换后的图片
- download : True = 从互联网上下载数据集,并把数据集放在root目录下. 如果数据集之前下载过,将处理过的数据(minist.py中有相关函数)放在processed文件夹下。
- 除此以外还需要对
target_transform
进一步了解:
一个函数,输入为target,输出对其的转换。
torchvision.transforms.Compose(transforms)
例如:
torchvision.transforms.Compose([
torchvision.transforms.Resize(224), # 缩放图片,保持长宽比不变,最短边为224像素
torchvision.transforms.CenterCrop(10),# 将给定的PIL.Image进行中心切割,得到给定的size,size可以是tuple,(target_height, target_width)。size也可以是一个Integer,在这种情况下,切出来的图片的形状是正方形。
torchvision.transforms.ToTensor(), # 把一个取值范围是[0,255]的PIL.Image或者shape为(H,W,C)的numpy.ndarray,转换成形状为[C,H,W],取值范围是[0,1.0]的torch.FloadTensor
torchvision.transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化(正则化)至[-1, 1]
])
- 代码使用
import torchvision
# 获取数据集
train_data = torchvision.datasets.MNIST(root='mnist',
train=True,
transform = torchvision.transforms.ToTensor(),
download=True)
test_data = torchvision.datasets.MNIST(root='mnist',
train=False,
transform = torchvision.transforms.ToTensor(),
download=True)
# 属性测试
num_sample = train_data.__len__()
print(num_sample) # 60000
item = train_data.__getitem__(0)
或者
import torchvision
# 数据集的预处理
data_tf = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize([0.5],[0.5])
])
data_path = r'./mnist'
# 获取数据集
train_data = torchvision.datasets.MNIST(data_path, train=True, transform = data_tf, download=True)
test_data = torchvision.datasets.MNIST(data_path, train=False, transform = data_tf, download=True)
其他数据集
-
CIFAR10
- API
CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
- 介绍:
该数据集共有60000张彩色图像,这些图像是32*32,分为10个类,每类6000张图。这里面有50000张用于训练,构成了5个训练批,每一批10000张图;另外10000用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。抽剩下的就随机排列组成了训练批。注意一个训练批中的各类图像并不一定数量相同,总的来看训练批,每一类都有5000张图。 - 图示
- API
-
CIFAR100
- API
CIFAR100(root, train=True, transform=None, target_transform=None, download=False)
- 介绍:
这个数据集就像CIFAR-10,除了它有100个类,每个类包含600个图像。,每类各有500个训练图像和100个测试图像。CIFAR-100中的100个类被分成20个超类。每个图像都带有一个“精细”标签(它所属的类)和一个“粗糙”标签(它所属的超类)
以下是CIFAR-100中的类别列表:
- API
超类 | 类别 |
---|---|
水生哺乳动物 | 海狸,海豚,水獭,海豹,鲸鱼 |
鱼 | 水族馆的鱼,比目鱼,射线,鲨鱼,鳟鱼 |
花卉 | 兰花,罂粟花,玫瑰,向日葵,郁金香 |
食品容器 | 瓶子,碗,罐子,杯子,盘子 |
水果和蔬菜 | 苹果,蘑菇,橘子,梨,甜椒 |
家用电器 | 时钟,电脑键盘,台灯,电话机,电视机 |
家用家具 | 床,椅子,沙发,桌子,衣柜 |
昆虫 | 蜜蜂,甲虫,蝴蝶,毛虫,蟑螂 |
大型食肉动物 | 熊,豹,狮子,老虎,狼 |
大型人造户外用品 | 桥,城堡,房子,路,摩天大楼 |
大自然的户外场景 | 云,森林,山,平原,海 |
大杂食动物和食草动物 | 骆驼,牛,黑猩猩,大象,袋鼠 |
中型哺乳动物 | 狐狸,豪猪,负鼠,浣熊,臭鼬 |
非昆虫无脊椎动物 | 螃蟹,龙虾,蜗牛,蜘蛛,蠕虫 |
人 | 宝贝,男孩,女孩,男人,女人 |
爬行动物 | 鳄鱼,恐龙,蜥蜴,蛇,乌龟 |
小型哺乳动物 | 仓鼠,老鼠,兔子,母老虎,松鼠 |
树木 | 枫树,橡树,棕榈,松树,柳树 |
车辆1 | 自行车,公共汽车,摩托车,皮卡车,火车 |
车辆2 | 割草机,火箭,有轨电车,坦克,拖拉机 |