【torch加载数据集并显示】

标题【torch加载数据集并显示】

注:以下代码是Juyter notebook实现

torchvision包含以下数据集:

MNIST
COCO(用于图像标注和目标检测)
LSUN Classification
ImageFolder
Imagenet-12
CIFAR10 and CIFAR100
STL10

1. 加载MNIST数据集

import torchvision
from torchvision import transforms
import torch
from torchvision.utils import save_image

dataset = torchvision.datasets.MNIST(root = 'data',
                                     train= True,
                                     transform= transforms.ToTensor(),
                                     download= True)
data_loader = torch.utils.data.DataLoader(dataset = dataset,
                                          batch_size = 100,
                                          shuffle = True)                                    
                                 
'''
MNIST函数原型:

MNIST(root, train=True, transform=None, target_transform=None, download=False)

root:下载到哪里,或从哪里读取

train:True为训练集,False为测试集

download:是否从网络下载
save_image函数原型:
torchvision.utils.save_image(tensor, filename, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)
参考:
https://blog.csdn.net/weixin_43723625/article/details/108159190?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522167939039516800186592443%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fall.%2522%257D&request_id=167939039516800186592443&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~first_rank_ecpm_v1~rank_v31_ecpm-1-108159190-null-null.142^v74^control_1,201^v4^add_ask,239^v2^insert_chatgpt&utm_term=torchvision.utils.save_image%28%29%E5%8F%82%E6%95%B0%E8%AF%B4%E6%98%8E&spm=1018.2226.3001.4187

类似函数:make_grid
参考:https://blog.csdn.net/u012343179/article/details/83007296?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522167939066516800182128063%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fall.%2522%257D&request_id=167939066516800182128063&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~first_rank_ecpm_v1~rank_v31_ecpm-2-83007296-null-null.142^v74^control_1,201^v4^add_ask,239^v2^insert_chatgpt&utm_term=torchvision.utils.make_grid%E5%8F%82%E6%95%B0%E8%AF%B4%E6%98%8E&spm=1018.2226.3001.4187
'''

2. 拼接保存为bmp格式

for x,y in data_loader:
   print(x.size(), x.dtype, x.max())
   print(y)
   save_image(x, 'mnist100.bmp', nrow= 10, pad_value=1) # x为图像的tensor数据,'mnist100.bmp'图像命名,nrow=10行,pad_value=1子图间隔为1 
   break # 只输出前100个作为展示

输出:

torch.Size([100, 1, 28, 28]) torch.float32 tensor(1.)
tensor([3, 4, 8, 2, 1, 2, 7, 9, 3, 0, 0, 1, 4, 1, 5, 8, 4, 7, 8, 5, 6, 6, 6, 0,
        6, 0, 0, 8, 1, 1, 8, 4, 1, 1, 4, 8, 6, 1, 0, 2, 4, 2, 8, 9, 5, 2, 6, 8,
        4, 1, 1, 7, 4, 8, 1, 3, 8, 1, 5, 9, 3, 3, 7, 2, 1, 2, 4, 6, 7, 4, 3, 0,
        2, 9, 7, 4, 0, 7, 6, 9, 5, 0, 8, 2, 6, 7, 7, 2, 8, 2, 6, 9, 2, 0, 2, 1,
        2, 2, 0, 9])

3. 显示保存的图片

import matplotlib.pyplot as plt
import cv2
img = cv2.imread('mnist100.bmp')
plt.imshow(img)
plt.xticks([])  # 去掉横坐标值
plt.yticks([])  # 去掉纵坐标值
plt.show()

输出:
在这里插入图片描述

4. 扩展:加载彩色CIFAR10

导入先前的必备库
输出彩色拼图

dataset = torchvision.datasets.CIFAR10(root = 'cifar10data',
                                     train= True,
                                     transform= transforms.ToTensor(),
                                     download= True)
data_loader = torch.utils.data.DataLoader(dataset = dataset,
                                          batch_size = 100,
                                          shuffle = True)   
for x,y in data_loader:
    print(x.size(), x.dtype, x.max())
    print(y)
    save_image(x, 'cifar100.jpg', nrow= 10, pad_value=1) # x为图像的tensor数据,'mnist100.bmp'图像命名,nrow=10行,pad_value=1子图间隔为1 
    break # 只输出前100个作为展示                                                                     

输出:

torch.Size([100, 3, 32, 32]) torch.float32 tensor(1.)
tensor([6, 4, 5, 3, 1, 4, 5, 9, 3, 7, 3, 9, 6, 4, 7, 3, 4, 0, 0, 8, 3, 0, 5, 2,
        3, 9, 1, 0, 7, 9, 7, 7, 1, 8, 1, 2, 8, 9, 1, 8, 9, 0, 0, 2, 2, 2, 5, 2,
        1, 2, 1, 3, 9, 8, 0, 3, 0, 7, 7, 0, 8, 0, 1, 3, 8, 8, 3, 9, 3, 3, 8, 6,
        5, 5, 1, 5, 6, 2, 3, 9, 5, 5, 5, 2, 6, 4, 8, 7, 3, 4, 1, 9, 9, 3, 0, 7,
        6, 7, 5, 3])

显示彩色拼图

import matplotlib.pyplot as plt
import cv2
img = cv2.imread('cifar100.bmp')
plt.imshow(img)
plt.xticks([])  # 去掉横坐标值
plt.yticks([])  # 去掉纵坐标值

plt.show()

输出:
在这里插入图片描述

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值