[Pytorch系列-33]:数据集 - torchvision与MNIST数据集

本文详细介绍了TorchVision库的用途,包括其在处理图像数据集如MNIST上的应用。通过示例展示了如何安装TorchVision、下载和导入MNIST数据集,以及如何对数据进行预处理和批量读取。文章还提供了数据预处理的多种方式,如叠加噪声、灰度显示等。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

 本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121055489


目录

第1章 TorchVision概述

1.1 TorchVision

1.2 TorchVision的安装

1.3 TorchVision官网的数据集

1.4 TorchVision常见的数据集概述

第2章 MNIST数据集

2.1 MNIST数据集介绍

2.2 样本数据与样本标签格式

2.3 MNIST数据的下载与导入

2.4 对样本数据预处理

2.5 批量数据读取与显示


第1章 TorchVision概述

1.1 TorchVision

Pytorch非常有用的工具集:

  • torchtext:处理自然语言
  • torchaudio:处理音频的
  • torchvision:处理图像视频的。

torchvision包含一些常用的数据集、模型、转换函数等等。本文重点放在torchvision的数据集上。

1.2 TorchVision的安装

pip install torchvision 

1.3 TorchVision官网的数据集

https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/

1.4 TorchVision常见的数据集概述

  • MNIST
  • CIFAR10
  • CIFAR100
  • COCO(用于图像标注和目标检测)(Captioning and Detection)
  • LSUN Classification
  • ImageNet Flower
  • Imagenet-12
  • STL10

第2章 MNIST数据集

2.1 MNIST数据集介绍

MNIST数据集http://yann.lecun.com/exdb/

 备注 :可以先把样本数据下载本地,以提升程序调试的效率。最终的产品可以远程下载数据。

  • 每张图片大小:28*28.
  • 单通道的黑白色图片,即(batch_size, channels, Height, Width) =(batch_size, 1, 28, 28)

2.2 样本数据与样本标签格式

2.3 MNIST数据的下载与导入

(1)操作函数MNIST()的解读

MNIST (root, train=True, transform=None, target_transform=None, download=False)

参数说明:

  • root : 文件存放路的根路径,下载的文件存放在该路径下,processed/training.pt 和 processed/test.pt 的主目录
  • train : True = 训练集, False = 测试集
  • target_transform:导入数据时,是否需要对数据格式进行转换,一个函数,原始图片作为输入,返回一个转换后的图片。有时候神经网络所需要的尺寸与数据集提供的尺寸不一致,则可以通过此方法进行转换。
  • download : True = 从互联网上下载数据集,并把数据集放在root目录下. 如果数据集之前下载过,将处理过的数据(minist.py中有相关函数)放在processed文件夹下。

(2)代码实例

#环境准备
import numpy as np              # numpy数组库
import math                     # 数学运算库
import matplotlib.pyplot as plt # 画图库

import torch             # torch基础库
import torchvision.datasets as dataset  #公开数据集的下载和管理
import torchvision.transforms as transforms  #公开数据集的预处理库,格式转换
import torchvision.utils as utils 
import torch.utils.data as data_utils  #对数据集进行分批加载的工具集

print("Hello World")
print(torch.__version__)
print(torch.cuda.is_available())
Hello World
1.8.0
False

#2-1 准备数据集
train_data = dataset.MNIST(root = "mnist",
                           train = True,
                           transform = transforms.ToTensor(),
                           download = True)

#2-1 准备数据集
test_data = dataset.MNIST(root = "mnist",
                           train = False,
                           transform = transforms.ToTensor(),
                           download = True)

print(train_data)
print("size=", len(train_data))
print("")
print(test_data)
print("size=", len(test_data))
Dataset MNIST
    Number of datapoints: 60000
    Root location: mnist
    Split: Train
    StandardTransform
Transform: ToTensor()
size= 60000

Dataset MNIST
    Number of datapoints: 10000
    Root location: mnist
    Split: Test
    StandardTransform
Transform: ToTensor()
size= 10000

2.4 对样本数据预处理

(1)原图不叠加噪声显示

#原图不叠加噪声
#获取一张图片数据
print("原始图片")
image, label = train_data[0]
print("torch image shape:", image.shape)
print("torch image label:", label)

print("\n单通道原始图片:numpy")
image = image.numpy().transpose(1,2,0) 
print("numpy image shape:", image.shape)
print("numpy image label:", label)

print("\n不叠加噪声, 原图显示")

plt.imshow(image)
plt.show()
原始图片
torch image shape: torch.Size([1, 28, 28])
torch image label: 5

单通道原始图片:numpy
numpy image shape: (28, 28, 1)
numpy image label: 5

不叠加噪声, 原图显示

(2)原图叠加噪声

#原图叠加噪声
#获取一张图片数据
print("原始图片")
image, label = train_data[0]
print("torch image shape:", image.shape)
print("torch image label:", label)

print("\n单通道原始图片:numpy")
image = image.numpy().transpose(1,2,0) 
print("numpy image shape:", image.shape)
print("numpy image label:", label)

print("\n叠加噪声, 平滑显示")
std = [0.5]
mean = [0.5]
image = image * std + mean

plt.imshow(image)
plt.show()
原始图片
torch image shape: torch.Size([1, 28, 28])
torch image label: 5

单通道原始图片:numpy
numpy image shape: (28, 28, 1)
numpy image label: 5

叠加噪声, 平滑显示

 

(3)#叠加噪声,灰度显示图片

#叠加噪声,灰度显示图片
print("原始图片")
image, label = train_data[0]
print("torch image shape:", image.shape)
print("torch image label:", label)

print("\n三通道灰度图片:torch")
image = utils.make_grid(image)
print("torch image shape:", image.shape)
print("torch image label:", label)

print("\n三通道灰度图片:numpy")
image = image.numpy().transpose(1,2,0) 
print("numpy image shape:", image.shape)
print("numpy image label:", label)

print("\n叠加噪声, 平滑显示")
std = [0.5]
mean = [0.5]
image = image * std + mean

plt.imshow(image)
plt.show()
原始图片
torch image shape: torch.Size([1, 28, 28])
torch image label: 5

三通道灰度图片:torch
torch image shape: torch.Size([3, 28, 28])
torch image label: 5

三通道灰度图片:numpy
numpy image shape: (28, 28, 3)
numpy image label: 5

叠加噪声, 平滑显示

(4)#不叠加噪声,黑白显示图片

#不叠加噪声,黑白显示图片
print("原始图片")
image, label = train_data[0]
print("torch image shape:", image.shape)
print("torch image label:", label)

print("\n三通道灰度图片:torch")
image = utils.make_grid(image)
print("torch image shape:", image.shape)
print("torch image label:", label)

print("\n三通道灰度图片:numpy")
image = image.numpy().transpose(1,2,0) 
print("numpy image shape:", image.shape)
print("numpy image label:", label)

print("\n不叠加噪声,黑白显示")
plt.imshow(image)
plt.show()
print("numpy image shape:", image.shape)
原始图片
torch image shape: torch.Size([1, 28, 28])
torch image label: 5

三通道灰度图片:torch
torch image shape: torch.Size([3, 28, 28])
torch image label: 5

三通道灰度图片:numpy
numpy image shape: (28, 28, 3)
numpy image label: 5

不叠加噪声,黑白显示

2.5 批量数据读取与显示

(1)batch批量图片的读取

# 批量数据读取
train_loader = data_utils.DataLoader(dataset = train_data,
                                  batch_size = 64,
                                  shuffle = True)

test_loader = data_utils.DataLoader(dataset = test_data,
                                  batch_size = 64,
                                  shuffle = True)

print(train_loader)
print(test_loader)
print(len(train_loader), len(train_data)/64)
print(len(test_loader),  len(test_data)/64)
<torch.utils.data.dataloader.DataLoader object at 0x000002461EF4A1C0>
<torch.utils.data.dataloader.DataLoader object at 0x000002461ED66610>
938 937.5
157 156.25

(2)一个batch图片的显示

显示一个batch图片
print("获取一个batch组图片")
imgs, labels = next(iter(train_loader))
print(imgs.shape)
print(labels.shape)
print(labels.size()[0])

print("\n合并成一张三通道灰度图片")
images = utils.make_grid(imgs)
print(images.shape)
print(labels.shape)

print("\n转换成imshow格式")
images = images.numpy().transpose(1,2,0) 
print(images.shape)
print(labels.shape)

print("\n显示样本标签")
#打印图片标签
for i in range(64):
    print(labels[i], end=" ")
    i += 1
    #换行
    if i%8 == 0:
        print(end='\n')

print("\n显示图片")
plt.imshow(images)
plt.show()
获取一个batch组图片
torch.Size([64, 1, 28, 28])
torch.Size([64])
64

合并成一张三通道灰度图片
torch.Size([3, 242, 242])
torch.Size([64])

转换成imshow格式
(242, 242, 3)
torch.Size([64])

显示样本标签
tensor(0) tensor(8) tensor(3) tensor(7) tensor(5) tensor(7) tensor(9) tensor(7) 
tensor(1) tensor(1) tensor(1) tensor(8) tensor(8) tensor(6) tensor(0) tensor(1) 
tensor(4) tensor(8) tensor(1) tensor(3) tensor(3) tensor(6) tensor(4) tensor(4) 
tensor(0) tensor(5) tensor(8) tensor(5) tensor(9) tensor(3) tensor(7) tensor(5) 
tensor(2) tensor(1) tensor(0) tensor(6) tensor(8) tensor(8) tensor(9) tensor(6) 
tensor(1) tensor(3) tensor(5) tensor(3) tensor(4) tensor(4) tensor(3) tensor(1) 
tensor(4) tensor(1) tensor(4) tensor(4) tensor(9) tensor(8) tensor(7) tensor(2) 
tensor(3) tensor(1) tensor(2) tensor(0) tensor(8) tensor(1) tensor(1) tensor(4) 

显示图片



作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

 本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121055489

### 解决方案 #### 1. 环境一致性检查 确保当前使用的 Python 解释器安装 PyTorch 的环境一致。如果使用的是 Anaconda 创建的虚拟环境,则需确认 PyCharm 中设置的解释器路径指向该虚拟环境下的 Python 可执行文件,而不是默认的全局 Python 路径[^1]。 #### 2. 版本兼容性验证 PyTorch 对 CUDA、cuDNN 和其他依赖项有严格的版本要求。如果这些组件之间的版本不匹配,可能会导致 `import` 失败。建议按照官方文档推荐的方式重新安装合适的版本组合[^3]。例如: ```bash conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch ``` 此命令会自动处理依赖关系并安装适合的 cuDNN 和 CUDA 工具链。 #### 3. 验证安装成功否 即使通过 `pip list` 或者 Conda 列表看到 PyTorch 安装完成,仍需进一步测试其可用性。可以在终端启动 Python 并运行以下代码来验证: ```python import torch print(torch.__version__) print(torch.cuda.is_available()) ``` 上述脚本能打印出 PyTorch 的版本号以及 GPU 是否可用的信息。如果这两步均正常工作,则表明基础环境配置无误[^5]。 #### 4. PyCharm 设置调整 对于 IDE 类型的问题,通常是因为项目未正确关联到目标 Python 解释器所致。具体操作如下: - 打开 **Settings/Preferences | Project: <project_name> | Python Interpreter**; - 如果列表里没有所需的 conda env,点击齿轮图标选择 “Add...”,然后指定 anaconda 下对应的 python.exe 文件位置[^4]。 #### 5. 清理残留数据 有时旧版库或错误索引可能干扰新安装的结果。可以考虑清理 pip 缓存或者删除 site-packages 目录下有关于 torch 的目录后再重试一次完整的安装流程。 --- ### 示例代码片段 以下是用于检测 Torch 功能的一个简单例子: ```python if __name__ == "__main__": import torch device = 'cuda' if torch.cuda.is_available() else 'cpu' tensor_example = torch.tensor([1., 2., 3.], dtype=torch.float).to(device) print(f"Torch version installed is {torch.__version__}") print(f"Using device type '{device}'") ``` ---
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

文火冰糖的硅基工坊

你的鼓励是我前进的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值