Pytorch之下载数据集

如果你的torchvision还没有装好,可以参看https://blog.csdn.net/qq_37385726/article/details/81744485

(应对于WIndows下Python3.6,cuda=none)

 

目录

1.代码

MNIST

CIFAR

效果


 

 

1.代码

Pytorch中有很多常用的数据集模块,预先保存在了torchvision.datasets中,要用的时候下载即可。

torchvision.datasets中包含了以下数据集

  • MNIST
  • COCO(用于图像标注和目标检测)(Captioning and Detection)
  • LSUN Classification
  • ImageFolder
  • Imagenet-12
  • CIFAR10 and CIFAR100
  • STL10

现介绍MNIST和CIFAR的下载方式

MNIST

dset.MNIST(root, train=True, transform=None, target_transform=None, download=False)

参数说明:

- root : processed/training.pt 和 processed/test.pt 的主目录

- train : True = 训练集, False= 测试集

- download : True = 从互联网上下载数据集,并把数据集放在root目录下. 如果数据集之前下载过,就赋值为False,不再重复下载

CIFAR

dset.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
 
dset.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)

参数说明:

- root : cifar-10-batches-py 的根目录

- train : True = 训练集, False = 测试集

- transform : 定义对于下载到的数据的数据变化形式,利用torchvision.transforms中的数据变换函数处理

- download : True = 从互联上下载数据,并将其放在root目录下。如果数据集已经下载,什么都不干。

 

import torchvision.datasets as dsets
import  torchvision.transforms as transforms
from  PIL import Image
 
#super parameters
DOWNLOAD = True
 
 
#定义数据变换
transform1 = transforms.ToTensor()  #可以把下载到的数据转化成张量格式
 
#transforms.Compose()定义多重数据变化
transform2 = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  #归一化[-1,1]
 
mT_trainset = dsets.MNIST(root='./MNIST/Tensor/training',train=True,transform=transform1,download=DOWNLOAD)
mT_testset = dsets.MNIST(root='./MNIST/Tensor/test',train=False,transform=transform1,download=DOWNLOAD)
cT_trainset = dsets.CIFAR10(root='./CIFAR10/Tensor/training',train=True,transform=transform1,download=DOWNLOAD)
cT_testset = dsets.CIFAR10(root='./CIFAR10/Tensor/test',train=False,transform=transform1,download=DOWNLOAD)
 
 
mN_trainset = dsets.MNIST(root='./MNIST/Normal/training',train=True,transform=transform2,download=DOWNLOAD)
mN_testset = dsets.MNIST(root='./MNIST/Normal/test',train=False,transform=transform2,download=DOWNLOAD)
cN_trainset = dsets.CIFAR10(root='./CIFAR10/Normal/training',train=True,transform=transform2,download=DOWNLOAD)
cN_testset = dsets.CIFAR10(root='./CIFAR10/Normal/test',train=False,transform=transform2,download=DOWNLOAD)

 

效果

  • 3
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值