【pytorch】带batch的tensor类型图像显示

项目场景

pytorch训练时我们一般把数据集放到数据加载器里,然后分批拿出来训练。训练前我们一般还要看一下训练数据长啥样,也就是训练数据集可视化。那么如何显示dataloader里面带batchtensor类型的图像呢?

显示图像

绘图最常用的库就是matplotlib

pip install matplotlib

显示图像会用到matplotlib.pyplot.imshow方法。查阅官方文档可知,该方法接收的图像的通道数要放到后面:
在这里插入图片描述
数据加载器中数据的维度是[B, C, H, W],我们每次只拿一个数据出来就是[C, H, W],而matplotlib.pyplot.imshow要求的输入维度是[H, W, C],所以我们需要交换一下数据维度,把通道数放到最后面,这里用到pytorch里面的permute方法(transpose方法也行,不过要交换两次,没这个方便,numpy中的transpose方法倒是可以一次交换完成),用法示例如下:

>>> x = torch.randn(2, 3, 5)
>>> x.size()
torch.Size([2, 3, 5])
>>> x.permute(1, 2, 0).size()
torch.Size([3, 5, 2])

代码示例

#%% 导入模块
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
#%% 下载数据集
train_file = datasets.MNIST(
    root='./dataset/',
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]),
    download=True
)
#%% 制作数据加载器
train_loader = DataLoader(
    dataset=train_file,
    batch_size=9,
    shuffle=True
)
#%% 训练数据可视化
images, labels = next(iter(train_loader))
print(images.size())  # torch.Size([9, 1, 28, 28])
plt.figure(figsize=(9, 9))
for i in range(9):
    plt.subplot(3, 3, i+1)
    plt.title(labels[i].item())
    plt.imshow(images[i].permute(1, 2, 0), cmap='gray')
    plt.axis('off')
plt.show()

这里以mnist数据集为例,演示一下显示效果。我这个代码其实还有一点小问题。数据增强的时候我不是进行标准化了嘛,就是在第7行代码:Normalize((0.1307,), (0.3081,))。所以,如果你想查看训练集的原始图像,还得反标准化。

  • 标准化:image = (image-mean)/std
  • 反标准化:image = image*std+mean

我拿imagenet中的一个蚂蚁和蜜蜂的子集做了一下实验,标准化前后的区别还是很明显的:
在这里插入图片描述

最终效果

在这里插入图片描述

引用参考

https://pytorch.org/docs/stable/tensors.html
https://matplotlib.org/api/_as_gen/matplotlib.pyplot.imshow.html

  • 4
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
好的,我可以帮你解答这个问题。 首先,PyTorch提供了许多预训练模型,包括图像分类、目标检测、语义分割等任务的模型。这些预训练模型在大规模数据集上进行了训练,可以用来进行迁移学习,快速实现新的任务。我们可以使用其中的图像分类模型,来对单张数字图像进行预测。 以下是使用PyTorch的预训练图像分类模型,实现对单张数字图像图像分类预测的示例代码: ```python import torch import torchvision import torchvision.transforms as transforms from PIL import Image # 加载预训练模型 model = torchvision.models.resnet18(pretrained=True) # 设置模型为评估模式 model.eval() # 定义转换 transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) ]) # 加载图像并进行转换 image = Image.open('test.jpg') image = transform(image) # 增加一个维度,使其变成一个batch的形式 image = image.unsqueeze(0) # 对图像进行预测 with torch.no_grad(): outputs = model(image) _, predicted = torch.max(outputs.data, 1) print('预测结果为:', predicted.item()) ``` 这段代码使用了ResNet-18作为预训练模型,对输入的图像进行了预测,并输出了预测结果。需要注意的是,这里的图像需要进行预处理,包括缩放、裁剪、标准化等操作,这些操作可以通过PyTorch的transforms模块进行实现。同时,由于预训练模型在训练时使用的是RGB格式的图像,因此需要将输入图像转换为RGB格式。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Xavier Jiezou

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值