Tensor与Image的相互转换(用PIL包)

最新学习的时候用到了torchvision中的MNIST数据,产生了一些疑问

前言

train_dataset = torchvision.datasets.MNIST(root='../MNISTdata', 
                                           train=True, 
                                           download=True)
                                           
for i, (data, labels) in enumerate(train_dataset):
    if i % 60000 == 0: # MNIST里面一共有60000个数据
        print(type(data))
"""
输出:
<class 'PIL.Image.Image'>
"""


type(train_dataset.data)
"""
输出:
torch.Tensor
"""

如图所示,也就是说我直接从MNIST数据集中取得的数据的数据类型与通过迭代方式取得的不一样,这让我很迷……

那到底是什么原因造成了这样的差异呢?
进到了MNIST的源码里(mnist.py)才发现,它对数据做了图像转换。

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):
        super(MNIST, self).__init__(root, transform=transform,
                                    target_transform=target_transform)
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            data_file = self.training_file
        else:
            data_file = self.test_file
        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))

以上是MNIST的初始化函数

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

以上是MNIST的迭代取数据方法,可以发现有以下这两行代码把数据从tensor转换到了Image类型

img, target = self.data[index], int(self.targets[index])
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode='L')

正文

实际的转换方法有很多,以下方法只是其中之一

1、tensor转换为image(PIL)

tensor→numpy→Image

numpy = tensor.numpy()
image = Image.fromarray(numpy())

2、image转换为tensor

采用torchvision.transforms方法

transform = torchvision.transforms.Compose([  
    transforms.ToTensor()])  
tensor = transform(image)
  • 5
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值