PyTorch - 数据集介绍(mnist、CIFAR10、CIFAR100)

参考自官网:torchvision.datasets

总介绍

torchvision.datasets中包含了以下数据集

  • MNIST
  • COCO(用于图像标注和目标检测)(Captioning and Detection)
  • LSUN Classification
  • ImageFolder
  • Imagenet-12
  • CIFAR10 and CIFAR100
  • STL10

详细介绍(以mnist手写数字集为例)

  • 数据集介绍
    60000个训练数据,10000个测试数据,每张图片大小28*28。
    单通道的黑白色图片,即(batch_size, channels, Height, Width) =(batch_size, 1, 28, 28)
  • 参数列表
    MNIST(root, train=True, transform=None, target_transform=None, download=False)
    参数说明:
    • root : processed/training.pt 和 processed/test.pt 的主目录
    • train : True = 训练集, False = 测试集
    • target_transform:一个函数,原始图片作为输入,返回一个转换后的图片
    • download : True = 从互联网上下载数据集,并把数据集放在root目录下. 如果数据集之前下载过,将处理过的数据(minist.py中有相关函数)放在processed文件夹下。
  • 除此以外还需要对target_transform进一步了解:
    一个函数,输入为target,输出对其的转换。
    torchvision.transforms.Compose(transforms)
    例如:
torchvision.transforms.Compose([
    torchvision.transforms.Resize(224), # 缩放图片,保持长宽比不变,最短边为224像素
    torchvision.transforms.CenterCrop(10),# 将给定的PIL.Image进行中心切割,得到给定的size,size可以是tuple,(target_height, target_width)。size也可以是一个Integer,在这种情况下,切出来的图片的形状是正方形。
    torchvision.transforms.ToTensor(), # 把一个取值范围是[0,255]的PIL.Image或者shape为(H,W,C)的numpy.ndarray,转换成形状为[C,H,W],取值范围是[0,1.0]的torch.FloadTensor
    torchvision.transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化(正则化)至[-1, 1]
 ])
  • 代码使用
import torchvision

# 获取数据集
train_data = torchvision.datasets.MNIST(root='mnist', 
                                        train=True, 
                                        transform = torchvision.transforms.ToTensor(), 
                                        download=True)
test_data = torchvision.datasets.MNIST(root='mnist', 
                                       train=False, 
                                       transform = torchvision.transforms.ToTensor(), 
                                       download=True)

# 属性测试
num_sample = train_data.__len__()
print(num_sample) # 60000
item = train_data.__getitem__(0)

或者

import torchvision


# 数据集的预处理
data_tf = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.5],[0.5])
    ])

data_path = r'./mnist'
# 获取数据集
train_data = torchvision.datasets.MNIST(data_path, train=True, transform = data_tf, download=True)
test_data = torchvision.datasets.MNIST(data_path, train=False, transform = data_tf, download=True)

其他数据集

  • CIFAR10

    • API
      CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
    • 介绍:
      该数据集共有60000张彩色图像,这些图像是32*32,分为10个类,每类6000张图。这里面有50000张用于训练,构成了5个训练批,每一批10000张图;另外10000用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。抽剩下的就随机排列组成了训练批。注意一个训练批中的各类图像并不一定数量相同,总的来看训练批,每一类都有5000张图。
    • 图示 在这里插入图片描述
  • CIFAR100

    • API
      CIFAR100(root, train=True, transform=None, target_transform=None, download=False)
    • 介绍:
      这个数据集就像CIFAR-10,除了它有100个类,每个类包含600个图像。,每类各有500个训练图像和100个测试图像。CIFAR-100中的100个类被分成20个超类。每个图像都带有一个“精细”标签(它所属的类)和一个“粗糙”标签(它所属的超类)
      以下是CIFAR-100中的类别列表:
超类类别
水生哺乳动物海狸,海豚,水獭,海豹,鲸鱼
水族馆的鱼,比目鱼,射线,鲨鱼,鳟鱼
花卉兰花,罂粟花,玫瑰,向日葵,郁金香
食品容器瓶子,碗,罐子,杯子,盘子
水果和蔬菜苹果,蘑菇,橘子,梨,甜椒
家用电器时钟,电脑键盘,台灯,电话机,电视机
家用家具床,椅子,沙发,桌子,衣柜
昆虫蜜蜂,甲虫,蝴蝶,毛虫,蟑螂
大型食肉动物熊,豹,狮子,老虎,狼
大型人造户外用品桥,城堡,房子,路,摩天大楼
大自然的户外场景云,森林,山,平原,海
大杂食动物和食草动物骆驼,牛,黑猩猩,大象,袋鼠
中型哺乳动物狐狸,豪猪,负鼠,浣熊,臭鼬
非昆虫无脊椎动物螃蟹,龙虾,蜗牛,蜘蛛,蠕虫
宝贝,男孩,女孩,男人,女人
爬行动物鳄鱼,恐龙,蜥蜴,蛇,乌龟
小型哺乳动物仓鼠,老鼠,兔子,母老虎,松鼠
树木枫树,橡树,棕榈,松树,柳树
车辆1自行车,公共汽车,摩托车,皮卡车,火车
车辆2割草机,火箭,有轨电车,坦克,拖拉机
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值