使用make_grid多批次显示网格图像(使用CIFAR数据集介绍)

背景介绍

在机器学习的训练数据集中,我们经常使用多批次的训练来实现更好的训练效果,具体到cv领域,我们的训练数据集通常是[B,C,W,H]格式,其中,B是每个训练批次的大小,C是图片的通道数,如果是1则为灰度图像,如果是3则为彩色图像,W,H分别是图像的像素宽和像素高,在torchvision中,为我们提供了方便的方法显示多通道的图像显示成网格的格式

数据集介绍

这里使用机器学习中经典的CIFAR10数据集,具体可以参考博客CIFAR-10数据集详解与可视化_cifar10数据集可视化-CSDN博客

数据集读取

我们假设已经下载好CIFAR数据集保存在本地计算机的路径中,可以通过CIFAR函数进行读取

# 依赖的库环境
import torchvision
import torch
from torchvision.datasets import CIFAR10
import matplotlib.pyplot as plt
from torchvision.transforms import ToTensor,Compose,Resize

读取CIFAR数据集中的训练数据集

train_dataset = CIFAR10(r'D:\deep_learning\12_16\data', train=True, download=False,transform=ToTensor())

这里的转换方式是使用简单的ToTensor()将图片格式转换成经典的[C,W,H]格式,方便后续的可视化操作

此时我们可以简单地对数据集中的第一张图片进行可视化

img,label = train_dataset[0]
plt.imshow(img.permute(1,2,0))
plt.show()

构造批次数据集

如何构造批次的训练数据集呢?可以通过DataLoader的方式获得批次生成器,也可以通过torch.stack函数自定义地构成

cifar_img = torch.stack([train_dataset[i][0] for i in range(4)], dim=0)

这里使用列表推导式获得前4张图片组成的数据列表,通过torch.stack指定dim=0进行多个数据的堆加,这里需要注意的是,stack是在指定的维度新增一个维度进行多矩阵的合并,cat是在指定的维度上合并多个矩阵而不增加新的维度

cat与stack的区别

我们来具体看看两者的区别

cat_img = torch.cat([train_dataset[i][0] for i in range(4)],dim=0)
stack_img = torch.stack([train_dataset[i][0] for i in range(4)],dim=0)
print(f'cat_shape:{cat_img.shape}')
print(f'stack_shape:{stack_img.shape}')
cat_shape:torch.Size([12, 32, 32])
stack_shape:torch.Size([4, 3, 32, 32])

train_dataset[i][0]的形状为[3,32,32],当使用cat时,直接在第一维度上进行累加获得[12,32,32];使用stack时,在指定的第一维度上新增一个维度进行累加,有[4,3,32,32]

进行网格化显示

使用torchvision.utils.make_grid函数进行网格格式转换

train_dataset = CIFAR10(r'D:\deep_learning\12_16\data', train=True, download=False,transform=ToTensor())
cifar_img = torch.stack([train_dataset[i][0] for i in range(4)], dim=0)
img_grid = torchvision.utils.make_grid(cifar_img,nrow=4,normalize=True,pad_value=0.9,padding=1)
plt.imshow(img_grid.permute(1,2,0))
plt.show()

nrow是指定每一行的图片的数量,这里只有四张图片,所以是4,默认nrow=8

normalize是对图片数据进行标准化

pad_value是对图片间隔之间的像素进行填充的像素值

padding是指定图片之间的像素间隔数量

同时显示100张图片

train_dataset = CIFAR10(r'D:\deep_learning\12_16\data', train=True, download=False,transform=ToTensor())
cifar_img = torch.stack([train_dataset[i][0] for i in range(100)], dim=0)
img_grid = torchvision.utils.make_grid(cifar_img,nrow=10,normalize=True,pad_value=0.9,padding=1)
plt.imshow(img_grid.permute(1,2,0))
plt.show()

批次图片可视化

我们对使用DataLoader生成的批次数据进行可视化

if __name__=='__main__':
    train_dataset = CIFAR10(r'D:\deep_learning\12_16\data', train=True, download=False,transform=ToTensor())
    trainloader = DataLoader(train_dataset,shuffle=True,batch_size=128,num_workers=8)
    trainloader = iter(trainloader)
    trainloader_first_batch = next(trainloader)

    imgs,labels = trainloader_first_batch
    batch_grid = torchvision.utils.make_grid(imgs)
    plt.imshow(batch_grid.permute(1,2,0))
    plt.show()

对训练数据集更好的了解是为了在训练的时候获得更好的模型性能,欢迎大家讨论交流~


  • 21
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值