[Pytorch系列-33]:数据集 - torchvision与MNIST数据集

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

 本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121055489


目录

第1章 TorchVision概述

1.1 TorchVision

1.2 TorchVision的安装

1.3 TorchVision官网的数据集

1.4 TorchVision常见的数据集概述

第2章 MNIST数据集

2.1 MNIST数据集介绍

2.2 样本数据与样本标签格式

2.3 MNIST数据的下载与导入

2.4 对样本数据预处理

2.5 批量数据读取与显示


第1章 TorchVision概述

1.1 TorchVision

Pytorch非常有用的工具集:

  • torchtext:处理自然语言
  • torchaudio:处理音频的
  • torchvision:处理图像视频的。

torchvision包含一些常用的数据集、模型、转换函数等等。本文重点放在torchvision的数据集上。

1.2 TorchVision的安装

pip install torchvision 

1.3 TorchVision官网的数据集

https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/

1.4 TorchVision常见的数据集概述

  • MNIST
  • CIFAR10
  • CIFAR100
  • COCO(用于图像标注和目标检测)(Captioning and Detection)
  • LSUN Classification
  • ImageNet Flower
  • Imagenet-12
  • STL10

第2章 MNIST数据集

2.1 MNIST数据集介绍

MNIST数据集http://yann.lecun.com/exdb/

 备注 :可以先把样本数据下载本地,以提升程序调试的效率。最终的产品可以远程下载数据。

  • 每张图片大小:28*28.
  • 单通道的黑白色图片,即(batch_size, channels, Height, Width) =(batch_size, 1, 28, 28)

2.2 样本数据与样本标签格式

2.3 MNIST数据的下载与导入

(1)操作函数MNIST()的解读

MNIST (root, train=True, transform=None, target_transform=None, download=False)

参数说明:

  • root : 文件存放路的根路径,下载的文件存放在该路径下,processed/training.pt 和 processed/test.pt 的主目录
  • train : True = 训练集, False = 测试集
  • target_transform:导入数据时,是否需要对数据格式进行转换,一个函数,原始图片作为输入,返回一个转换后的图片。有时候神经网络所需要的尺寸与数据集提供的尺寸不一致,则可以通过此方法进行转换。
  • download : True = 从互联网上下载数据集,并把数据集放在root目录下. 如果数据集之前下载过,将处理过的数据(minist.py中有相关函数)放在processed文件夹下。

(2)代码实例

#环境准备
import numpy as np              # numpy数组库
import math                     # 数学运算库
import matplotlib.pyplot as plt # 画图库

import torch             # torch基础库
import torchvision.datasets as dataset  #公开数据集的下载和管理
import torchvision.transforms as transforms  #公开数据集的预处理库,格式转换
import torchvision.utils as utils 
import torch.utils.data as data_utils  #对数据集进行分批加载的工具集

print("Hello World")
print(torch.__version__)
print(torch.cuda.is_available())
Hello World
1.8.0
False

#2-1 准备数据集
train_data = dataset.MNIST(root = "mnist",
                           train = True,
                           transform = transforms.ToTensor(),
                           download = True)

#2-1 准备数据集
test_data = dataset.MNIST(root = "mnist",
                           train = False,
                           transform = transforms.ToTensor(),
                           download = True)

print(train_data)
print("size=", len(train_data))
print("")
print(test_data)
print("size=", len(test_data))
Dataset MNIST
    Number of datapoints: 60000
    Root location: mnist
    Split: Train
    StandardTransform
Transform: ToTensor()
size= 60000

Dataset MNIST
    Number of datapoints: 10000
    Root location: mnist
    Split: Test
    StandardTransform
Transform: ToTensor()
size= 10000

2.4 对样本数据预处理

(1)原图不叠加噪声显示

#原图不叠加噪声
#获取一张图片数据
print("原始图片")
image, label = train_data[0]
print("torch image shape:", image.shape)
print("torch image label:", label)

print("\n单通道原始图片:numpy")
image = image.numpy().transpose(1,2,0) 
print("numpy image shape:", image.shape)
print("numpy image label:", label)

print("\n不叠加噪声, 原图显示")

plt.imshow(image)
plt.show()
原始图片
torch image shape: torch.Size([1, 28, 28])
torch image label: 5

单通道原始图片:numpy
numpy image shape: (28, 28, 1)
numpy image label: 5

不叠加噪声, 原图显示

(2)原图叠加噪声

#原图叠加噪声
#获取一张图片数据
print("原始图片")
image, label = train_data[0]
print("torch image shape:", image.shape)
print("torch image label:", label)

print("\n单通道原始图片:numpy")
image = image.numpy().transpose(1,2,0) 
print("numpy image shape:", image.shape)
print("numpy image label:", label)

print("\n叠加噪声, 平滑显示")
std = [0.5]
mean = [0.5]
image = image * std + mean

plt.imshow(image)
plt.show()
原始图片
torch image shape: torch.Size([1, 28, 28])
torch image label: 5

单通道原始图片:numpy
numpy image shape: (28, 28, 1)
numpy image label: 5

叠加噪声, 平滑显示

 

(3)#叠加噪声,灰度显示图片

#叠加噪声,灰度显示图片
print("原始图片")
image, label = train_data[0]
print("torch image shape:", image.shape)
print("torch image label:", label)

print("\n三通道灰度图片:torch")
image = utils.make_grid(image)
print("torch image shape:", image.shape)
print("torch image label:", label)

print("\n三通道灰度图片:numpy")
image = image.numpy().transpose(1,2,0) 
print("numpy image shape:", image.shape)
print("numpy image label:", label)

print("\n叠加噪声, 平滑显示")
std = [0.5]
mean = [0.5]
image = image * std + mean

plt.imshow(image)
plt.show()
原始图片
torch image shape: torch.Size([1, 28, 28])
torch image label: 5

三通道灰度图片:torch
torch image shape: torch.Size([3, 28, 28])
torch image label: 5

三通道灰度图片:numpy
numpy image shape: (28, 28, 3)
numpy image label: 5

叠加噪声, 平滑显示

(4)#不叠加噪声,黑白显示图片

#不叠加噪声,黑白显示图片
print("原始图片")
image, label = train_data[0]
print("torch image shape:", image.shape)
print("torch image label:", label)

print("\n三通道灰度图片:torch")
image = utils.make_grid(image)
print("torch image shape:", image.shape)
print("torch image label:", label)

print("\n三通道灰度图片:numpy")
image = image.numpy().transpose(1,2,0) 
print("numpy image shape:", image.shape)
print("numpy image label:", label)

print("\n不叠加噪声,黑白显示")
plt.imshow(image)
plt.show()
print("numpy image shape:", image.shape)
原始图片
torch image shape: torch.Size([1, 28, 28])
torch image label: 5

三通道灰度图片:torch
torch image shape: torch.Size([3, 28, 28])
torch image label: 5

三通道灰度图片:numpy
numpy image shape: (28, 28, 3)
numpy image label: 5

不叠加噪声,黑白显示

2.5 批量数据读取与显示

(1)batch批量图片的读取

# 批量数据读取
train_loader = data_utils.DataLoader(dataset = train_data,
                                  batch_size = 64,
                                  shuffle = True)

test_loader = data_utils.DataLoader(dataset = test_data,
                                  batch_size = 64,
                                  shuffle = True)

print(train_loader)
print(test_loader)
print(len(train_loader), len(train_data)/64)
print(len(test_loader),  len(test_data)/64)
<torch.utils.data.dataloader.DataLoader object at 0x000002461EF4A1C0>
<torch.utils.data.dataloader.DataLoader object at 0x000002461ED66610>
938 937.5
157 156.25

(2)一个batch图片的显示

显示一个batch图片
print("获取一个batch组图片")
imgs, labels = next(iter(train_loader))
print(imgs.shape)
print(labels.shape)
print(labels.size()[0])

print("\n合并成一张三通道灰度图片")
images = utils.make_grid(imgs)
print(images.shape)
print(labels.shape)

print("\n转换成imshow格式")
images = images.numpy().transpose(1,2,0) 
print(images.shape)
print(labels.shape)

print("\n显示样本标签")
#打印图片标签
for i in range(64):
    print(labels[i], end=" ")
    i += 1
    #换行
    if i%8 == 0:
        print(end='\n')

print("\n显示图片")
plt.imshow(images)
plt.show()
获取一个batch组图片
torch.Size([64, 1, 28, 28])
torch.Size([64])
64

合并成一张三通道灰度图片
torch.Size([3, 242, 242])
torch.Size([64])

转换成imshow格式
(242, 242, 3)
torch.Size([64])

显示样本标签
tensor(0) tensor(8) tensor(3) tensor(7) tensor(5) tensor(7) tensor(9) tensor(7) 
tensor(1) tensor(1) tensor(1) tensor(8) tensor(8) tensor(6) tensor(0) tensor(1) 
tensor(4) tensor(8) tensor(1) tensor(3) tensor(3) tensor(6) tensor(4) tensor(4) 
tensor(0) tensor(5) tensor(8) tensor(5) tensor(9) tensor(3) tensor(7) tensor(5) 
tensor(2) tensor(1) tensor(0) tensor(6) tensor(8) tensor(8) tensor(9) tensor(6) 
tensor(1) tensor(3) tensor(5) tensor(3) tensor(4) tensor(4) tensor(3) tensor(1) 
tensor(4) tensor(1) tensor(4) tensor(4) tensor(9) tensor(8) tensor(7) tensor(2) 
tensor(3) tensor(1) tensor(2) tensor(0) tensor(8) tensor(1) tensor(1) tensor(4) 

显示图片



作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

 本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121055489

  • 7
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
torchvision是一个用于构建计算机视觉模型的图形库,它是PyTorch深度学习框架的一部分。torchvision包含了几个重要的模块,包括加载数据的函数和常用的数据集接口(torchvision.datasets)、常用的模型结构(torchvision.models,包含了许多预训练模型,如AlexNet、VGG、ResNet等)、常用的图片变换方法(torchvision.transforms,如裁剪、旋转等)以及其他有用的方法(torchvision.utils)。torchvision通常与PyTorch一起安装,安装命令如下:pip install torchvision。安装时建议使用清华镜像以提高下载速度:pip install torchvision -i https://pypi.tuna.tsinghua.edu.cn/simple/。可以通过导入torchtorchvision模块来验证安装是否成功:import torch,import torchvision。如果没有报错,则表示安装成功。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* [torchvision详细介绍](https://blog.csdn.net/frighting_ing/article/details/121863387)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 33.333333333333336%"] - *2* [torchvision-0.9.0](https://download.csdn.net/download/weixin_45235219/87401907)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 33.333333333333336%"] - *3* [pytorchtorchvision安装](https://blog.csdn.net/u013230291/article/details/108487877)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 33.333333333333336%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

文火冰糖的硅基工坊

你的鼓励是我前进的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值