TypeError: img should be PIL Image. Got <class 'torch.Tensor'>

Background

I was loading the MNIST dataset in PyTorch and visualizing a batch with the following code:

import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt

# part 1: get the datasets; torchvision provides ready-made dataset APIs
mnist_train_dataset = datasets.MNIST(root="./data/",
                                     train=True,
                                     download=True,
                                     transform=transforms.Compose([
                                         transforms.ToTensor(),
                                         transforms.Normalize(mean=[0.5], std=[0.5]),
                                         transforms.Resize((28, 28))
                                     ]))

mnist_test_dataset = datasets.MNIST(root="./data/",
                                    train=False,
                                    download=True,
                                    transform=transforms.Compose([
                                        transforms.ToTensor(),
                                        transforms.Resize((28, 28))
                                    ]))

# part 2: load the data with DataLoader
data_loader_train = torch.utils.data.DataLoader(
    dataset=mnist_train_dataset,
    batch_size=128,
    shuffle=True
)

data_loader_test = torch.utils.data.DataLoader(
    dataset=mnist_test_dataset,
    batch_size = 1,
    shuffle=True
)


# part 3: visualize one batch to check the data
images, labels = next(iter(data_loader_train))
# ^ this line raises: TypeError: img should be PIL Image. Got <class 'torch.Tensor'>
img = torchvision.utils.make_grid(images)
img = img.numpy().transpose(1, 2, 0)  # CHW -> HWC, the layout matplotlib expects
std = mean = [0.5, 0.5, 0.5]
img = img * std + mean  # undo Normalize(mean=0.5, std=0.5)
# Calling imshow on the normalized data directly triggers the warning:
# "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers)."
# i.e. the values must first be mapped back into [0, 1].
print(labels.tolist())
plt.imshow(img)
plt.show()
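Why `img * std + mean` is needed: `Normalize(mean=[0.5], std=[0.5])` computes `(x - mean) / std`, which maps the `[0, 1]` output of `ToTensor` onto `[-1, 1]`, and floats outside `[0, 1]` are what make `imshow` clip. A minimal NumPy sketch of the forward and inverse formulas on an HWC array follows; the helper names `normalize`/`denormalize` and the extra `np.clip` are my own additions, not part of the original code.

```python
import numpy as np

# Same per-channel statistics as the post (make_grid expands 1-channel MNIST to 3 channels)
MEAN = np.array([0.5, 0.5, 0.5])
STD = np.array([0.5, 0.5, 0.5])


def normalize(img_hwc):
    """Same formula as transforms.Normalize: (x - mean) / std, broadcast over the channel axis."""
    return (img_hwc - MEAN) / STD


def denormalize(img_hwc):
    """Invert normalize() and clip to [0, 1] so plt.imshow accepts the float data without clipping warnings."""
    return np.clip(img_hwc * STD + MEAN, 0.0, 1.0)


# A tiny 2x2 RGB image with values in [0, 1], like ToTensor output after transpose(1, 2, 0)
original = np.array([[[0.0, 0.5, 1.0], [0.25, 0.75, 1.0]],
                     [[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]]])
normalized = normalize(original)   # values now lie in [-1, 1]
restored = denormalize(normalized)  # back to the original [0, 1] values
print(normalized.min(), normalized.max())
```

The clip is only a safety net: for data that really came out of `Normalize(0.5, 0.5)`, the inverse already lands in `[0, 1]`.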

Running it raises the following error:

Traceback (most recent call last):
  File "d:/GitHub/studyNote/pytorch基础/mnist.torch.py", line 45, in <module>
    images,labels = next(iter(data_loader_train))
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 560, in __next__
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 560, in <listcomp>
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\datasets\mnist.py", line 95, in __getitem__
    img = self.transform(img)
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\transforms\transforms.py", line 61, in __call__
    img = t(img)
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\transforms\transforms.py", line 196, in __call__
    return F.resize(img, self.size, self.interpolation)
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\transforms\functional.py", line 229, in resize
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
TypeError: img should be PIL Image. Got <class 'torch.Tensor'>

Thoughts:

The error says the transform needs a PIL image, and transforms happens to provide transforms.ToPILImage(), so the pipeline became:

transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
    transforms.Resize([28, 28]),
    transforms.ToPILImage()
])

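But it still fails with the same error. ToPILImage() only runs at the very end, so Resize still receives the tensor produced by ToTensor():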

Traceback (most recent call last):
  File "d:/GitHub/studyNote/pytorch基础/mnist.torch.py", line 45, in <module>
    images,labels = next(iter(data_loader_train))
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 560, in __next__
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 560, in <listcomp>
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\datasets\mnist.py", line 95, in __getitem__
    img = self.transform(img)
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\transforms\transforms.py", line 61, in __call__
    img = t(img)
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\transforms\transforms.py", line 196, in __call__
    return F.resize(img, self.size, self.interpolation)
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\transforms\functional.py", line 229, in resize
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
TypeError: img should be PIL Image. Got <class 'torch.Tensor'>
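The reason appending ToPILImage() changes nothing is that Compose simply feeds each step the output of the previous one, so a type check fails at the first step that receives the wrong type, regardless of what comes later. The pure-Python sketch below imitates that behaviour; `FakePILImage`, `FakeTensor`, `to_tensor`, `resize`, `to_pil_image`, `compose` and `describe` are all hypothetical stand-ins, not torchvision classes, and Normalize is left out for brevity.

```python
# Hypothetical stand-ins for PIL.Image.Image and torch.Tensor (NOT the real classes),
# used only to show how a Compose-style pipeline passes data from step to step.
class FakePILImage:
    pass


class FakeTensor:
    pass


def to_tensor(img):
    # Plays the role of ToTensor: PIL image in, tensor out
    if not isinstance(img, FakePILImage):
        raise TypeError(f"expected a PIL image, got {type(img).__name__}")
    return FakeTensor()


def resize(img):
    # Plays the role of Resize in older torchvision: PIL images only
    if not isinstance(img, FakePILImage):
        raise TypeError(f"img should be PIL Image. Got {type(img).__name__}")
    return FakePILImage()


def to_pil_image(img):
    # Plays the role of ToPILImage: tensor in, PIL image out
    if not isinstance(img, FakeTensor):
        raise TypeError(f"expected a tensor, got {type(img).__name__}")
    return FakePILImage()


def compose(*steps):
    # Same idea as transforms.Compose: apply the steps left to right
    def run(img):
        for step in steps:
            img = step(img)  # each step only ever sees the previous step's output
        return img
    return run


def describe(pipeline):
    # Feed the pipeline a "PIL image" (what MNIST passes in) and report the result type or the error
    try:
        return type(pipeline(FakePILImage())).__name__
    except TypeError as exc:
        return f"TypeError: {exc}"


# Appending ToPILImage does not help: Resize still sees the tensor from ToTensor
print(describe(compose(to_tensor, resize, to_pil_image)))  # prints: TypeError: img should be PIL Image. Got FakeTensor
# Moving Resize in front of ToTensor works
print(describe(compose(resize, to_tensor)))  # prints: FakeTensor
```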

A Bing search turned up a similar error on Stack Overflow:

train_transforms = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

TypeError: img should be PIL Image. Got <class 'torch.Tensor'>

– from https://stackoverflow.com/questions/57079219/img-should-be-pil-image-got-class-torch-tensor

The answer below it explains:

transforms.RandomHorizontalFlip() works on PIL.Images, not torch.Tensor. In your code above, you are applying transforms.ToTensor() prior to transforms.RandomHorizontalFlip(), which results in tensor.

As the official PyTorch documentation puts it,

transforms.RandomHorizontalFlip() horizontally flips the given PIL Image randomly with a given probability.

So, just change the order of the transformations in the code above, like below:

train_transforms = transforms.Compose([transforms.Resize(255),
                                       transforms.CenterCrop(224),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406],
                                                            [0.229, 0.224, 0.225])])

So the problem is the order: the PIL-only transform RandomHorizontalFlip has to run before ToTensor, i.e. ToTensor goes after it.

Solution:

We can try the same approach on our problem. Starting from:

transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
    transforms.Resize([28, 28]),
    transforms.ToPILImage()
])

Change it to:

transform=transforms.Compose([
    transforms.Resize([28, 28]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
    # transforms.ToPILImage()  # no longer needed
])

Curious about the ordering, I tried one more variant:

transform=transforms.Compose([
    transforms.Scale([28, 28]),  # deprecated alias of transforms.Resize; use Resize on newer torchvision
    transforms.Normalize(mean=[0.5], std=[0.5]),
    transforms.ToTensor()
])

This raised:

Traceback (most recent call last):
  File "d:/GitHub/studyNote/pytorch基础/mnist.torch.py", line 47, in <module>
    images,labels = next(iter(data_loader_train))
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 560, in __next__
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 560, in <listcomp>
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\datasets\mnist.py", line 95, in __getitem__
    img = self.transform(img)
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\transforms\transforms.py", line 61, in __call__
    img = t(img)
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\transforms\transforms.py", line 164, in __call__
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "E:\ProgramData\Miniconda3\envs\pytorch\lib\site-packages\torchvision\transforms\functional.py", line 201, in normalize
    raise TypeError('tensor is not a torch image.')
TypeError: tensor is not a torch image.

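So ToTensor also has to come before Normalize: Normalize only works on tensors, while Resize (in this torchvision version) only works on PIL images. The order that works is Resize, then ToTensor, then Normalize.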

If you discover anything else, feel free to add it in the comments.

Author: pprp
