pytorch下载加载mnist数据集

1.下载mnist

使用torchvision.datasets,其中含有一些常见的MNIST等数据集,使用方式:

train_data=torchvision.datasets.MNIST(
    root='MNIST',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True
)
test_data=torchvision.datasets.MNIST(
    root='MNIST',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True
)

root:表示下载位置,下载后,会在该位置中新建一个MNIST文件夹,底下还有一个raw文件夹

train:True下载就会是训练集,False下载就会是测试集

transform:表示转换方式

download:表示是否下载

下载完后会生成四个压缩包,分别代表着train的img和label以及test的img和label

变量train_data和test_data的类型分别为'torchvision.datasets.mnist.MNIST',如果想用到pytorch中的进行训练,就必须将变量改为torch

2.torch.utils.data.DataLoader( )

用from torch.utils.data import DataLoader进行导入,

train_load=DataLoader(dataset=train_data,batch_size=100,shuffle=True)
test_load=DataLoader(dataset=test_data,batch_size=100,shuffle=True)

随机加载批量大小为l00数据给train_load和test_load,每个变量都由两部分组成,用迭代器将两部分分开

train_x,train_y=next(iter(train_load))

其中train_x为属性值,type(train_x)=torch.Size([100, 1, 28, 28])#100个,channel为1,长宽为28*28,type(train_y)=torch.size([100])

3.opencv显示图片

import cv2

img=torchvision.utils.make_grid(train_x,nrow=10)#将train_x赋给一个宽为10的网格
#因为cv2显示的图片格式是(size,size,channel),但是img格式为(channel,size,size)
img = img.numpy().transpose(1,2,0)
cv2.imshow('img', img)
cv2.waitKey()
  • 11
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
pytorch中,可以使用torchvision包中的datasets模块来导入MNIST数据集。具体步骤如下: 1. 导入torchvision和torch包 ```python import torch import torchvision ``` 2. 定义数据转换 ```python transform = torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,), (0.3081,)) ]) ``` 这里使用了两个数据转换函数,分别是ToTensor()和Normalize()。ToTensor()函数将图像转换为张量,Normalize()函数对张量进行标准化操作,其中参数(0.1307,)和(0.3081,)是MNIST数据集的均值和标准差。 3. 加载数据集 ```python train_dataset = torchvision.datasets.MNIST('data', train=True, download=True, transform=transform) test_dataset = torchvision.datasets.MNIST('data', train=False, download=True, transform=transform) ``` 这里使用了datasets模块中的MNIST类,其中参数train=True表示加载训练集,train=False表示加载测试集。参数download=True表示如果本地没有数据集则自动下载,transform=transform表示使用上面定义的数据转换。 4. 创建数据迭代器 ```python train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False) ``` 这里使用了DataLoader类来创建数据迭代器,参数train_dataset表示使用的数据集,batch_size=32表示每次迭代使用的数据量为32,shuffle=True表示在每次迭代之前对数据进行随机打乱。 至此,MNIST数据集就成功导入到pytorch中了。可以通过train_loader和test_loader来获取训练集和测试集的数据。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值