[torch notes] ToTensor and Normalize in torchvision.transforms

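This article walks through the ToTensor and Normalize transforms in torchvision.transforms. ToTensor converts image data from the 0-255 intensity range into a float tensor in the 0-1 range. Normalize, when given a mean and std of 0.5 for every channel, then maps the 0-1 data to the interval [-1, 1]. Both transforms are commonly used when preprocessing inputs for deep learning models, so that the data is on the scale the model expects. Compose chains several transforms together, which makes it convenient to build more complex image-processing pipelines.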
ToTensor and Normalize in torchvision.transforms

Reference
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transforms.ToTensor()
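  • Rescales pixel intensities from the 0-255 range to the 0-1 range

  • Looking at the source: calling a ToTensor instance invokes the class's __call__ method, which in turn calls the to_tensor function in F, where F is functional.py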

    • class ToTensor(object):
          """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
      
          Converts a PIL Image or numpy.ndarray (H x W x C) in the range
          [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
          if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
          or if the numpy.ndarray has dtype = np.uint8
      
          In the other cases, tensors are returned without scaling.
          """
      
          def __call__(self, pic):
              """
              Args:
                  pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
      
              Returns:
                  Tensor: Converted image.
              """
              # this calls the to_tensor function in functional
              return F.to_tensor(pic)
      
          def __repr__(self):
              return self.__class__.__name__ + '()'
      
  • Source of to_tensor:

    • def to_tensor(pic):
          """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
          See ``ToTensor`` for more details.
          Args:
              pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
          Returns:
              Tensor: Converted image.
          """
          if not(_is_pil_image(pic) or _is_numpy_image(pic)):
              raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
      
          if isinstance(pic, np.ndarray):
              # handle numpy array
              if pic.ndim == 2:
                  pic = pic[:, :, None]
      
              img = torch.from_numpy(pic.transpose((2, 0, 1)))
              # backward compatibility
              if isinstance(img, torch.ByteTensor):
                  return img.float().div(255)
              else:
                  return img
      
          if accimage is not None and isinstance(pic, accimage.Image):
              nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
              pic.copyto(nppic)
              return torch.from_numpy(nppic)
      
          # handle PIL Image
          if pic.mode == 'I':
              img = torch.from_numpy(np.array(pic, np.int32, copy=False))
          elif pic.mode == 'I;16':
              img = torch.from_numpy(np.array(pic, np.int16, copy=False))
          elif pic.mode == 'F':
              img = torch.from_numpy(np.array(pic, np.float32, copy=False))
          elif pic.mode == '1':
              img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
          else:
              img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
          # PIL image mode: L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK
          if pic.mode == 'YCbCr':
              nchannel = 3
          elif pic.mode == 'I;16':
              nchannel = 1
          else:
              nchannel = len(pic.mode)
          img = img.view(pic.size[1], pic.size[0], nchannel)
          # put it from HWC to CHW format
          # yikes, this transpose takes 80% of the loading time/CPU
          img = img.transpose(0, 1).transpose(0, 2).contiguous()
          if isinstance(img, torch.ByteTensor):
              return img.float().div(255)
          else:
              return img
      
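  • For an RGB image, to_tensor does four main things

    • 1. img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) converts the image into a 1-D tensor of type uint8

    • 2. nchannel = len(pic.mode) sets the number of channels, and img = img.view(pic.size[1], pic.size[0], nchannel) reshapes img into a 3-D tensor of shape (H, W, C)

    • 3. img = img.transpose(0, 1).transpose(0, 2).contiguous(): transpose swaps two axes, like a matrix transpose in 2-D, so the two calls turn (H, W, C) into (C, H, W); contiguous then copies the data so that the underlying 1-D storage follows the new axis order

      • Operations such as torch.view need a tensor whose memory layout is compatible with the requested shape. transpose and permute do not touch the underlying 1-D array; they create new tensor metadata with different strides. view promises not to modify or copy the data, only to look at it with a new shape, so calling view on a tensor that transpose or permute has made non-contiguous generally raises an error in PyTorch
    • 4. return img.float().div(255) converts to float and divides by 255 (the data was uint8 in the range 0-255, so after dividing by 255 it lies in the range 0-1)

    • With that, the image has been converted into a tensor in the 0-1 range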
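
The four steps above can be replayed by hand on a tiny example. This is only a sketch under stated assumptions: the 2x2 image bytes are made up, and the flat tensor is built with torch.tensor rather than the ByteStorage call used in the library source, so that the example does not depend on PIL.

```python
import torch

# Hypothetical 2x2 RGB image as raw bytes, pixel by pixel in (H, W, C) order:
# row 0: red, green; row 1: blue, white
raw = bytes([255, 0, 0,   0, 255, 0,
             0, 0, 255,   255, 255, 255])
height, width, nchannel = 2, 2, 3

# Step 1: a flat 1-D uint8 tensor (stand-in for the ByteStorage call)
flat = torch.tensor(list(raw), dtype=torch.uint8)

# Step 2: look at the flat data as (H, W, C)
hwc = flat.view(height, width, nchannel)

# Step 3: two transposes give (C, H, W); only the strides change here,
# so the result is a non-contiguous view of the same storage
chw_view = hwc.transpose(0, 1).transpose(0, 2)
print(chw_view.is_contiguous())

# view cannot flatten this layout without copying, so it raises
try:
    chw_view.view(-1)
    view_failed = False
except RuntimeError:
    view_failed = True
print(view_failed)

# contiguous() copies the data into (C, H, W) memory order
chw = chw_view.contiguous()

# Step 4: uint8 in 0-255 becomes float32 in 0-1
out = chw.float().div(255)
print(out.shape)
print(out[0])  # the red channel
```

The two transposes are equivalent to a single hwc.permute(2, 0, 1); the library source just spells it out as two axis swaps.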

transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  • With these arguments, maps the 0-1 range to [-1, 1]

  • class Normalize(object):
        """Normalize a tensor image with mean and standard deviation.
        Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
        will normalize each channel of the input ``torch.*Tensor`` i.e.
        ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
    
        .. note::
            This transform acts out of place, i.e., it does not mutates the input tensor.
    
        Args:
            mean (sequence): Sequence of means for each channel.
            std (sequence): Sequence of standard deviations for each channel.
        """
    
        def __init__(self, mean, std, inplace=False):
            self.mean = mean
            self.std = std
            self.inplace = inplace
    
        def __call__(self, tensor):
            """
            Args:
                tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
    
            Returns:
                Tensor: Normalized Tensor image.
            """
            return F.normalize(tensor, self.mean, self.std, self.inplace)
    
        def __repr__(self):
            return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
    
  • As with ToTensor, calling the transform invokes __call__; at this point the input tensor is in the 0-1 range

  • It then calls the normalize function in functional; the source is as follows:

    • def normalize(tensor, mean, std, inplace=False):
          """Normalize a tensor image with mean and standard deviation.
      
          .. note::
              This transform acts out of place by default, i.e., it does not mutates the input tensor.
      
          See :class:`~torchvision.transforms.Normalize` for more details.
      
          Args:
              tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
              mean (sequence): Sequence of means for each channel.
              std (sequence): Sequence of standard deviations for each channely.
      
          Returns:
              Tensor: Normalized Tensor image.
          """
          if not _is_tensor_image(tensor):
              raise TypeError('tensor is not a torch image.')
      
          if not inplace:
              tensor = tensor.clone()
      
          mean = torch.tensor(mean, dtype=torch.float32)
          std = torch.tensor(std, dtype=torch.float32)
          tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
          return tensor
      
      
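    • Analysis

      • inplace defaults to False, so the tensor is cloned first
      • mean and std are both the tuple (0.5, 0.5, 0.5); torch.tensor turns each into a 1-D float32 tensor
      • The actual transformation happens in tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
        • mean[:, None, None] and std[:, None, None] reshape them to (3, 1, 1) so that they broadcast over the (C, H, W) tensor; each value x in [0, 1] becomes (x - 0.5) / 0.5, which lies in [-1, 1]
      • That completes the transformation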
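
The broadcasting step in this analysis can be checked directly with plain torch. A minimal sketch, assuming a made-up (3, 1, 3) input whose values are the two endpoints and the midpoint of the 0-1 range:

```python
import torch

# Made-up input: each of the 3 channels holds the values 0.0, 0.5, 1.0
x = torch.tensor([[[0.0, 0.5, 1.0]]]).repeat(3, 1, 1)  # shape (3, 1, 3)

mean = torch.tensor((0.5, 0.5, 0.5), dtype=torch.float32)
std = torch.tensor((0.5, 0.5, 0.5), dtype=torch.float32)

# Indexing with None adds axes of size 1: (3,) becomes (3, 1, 1)
print(mean[:, None, None].shape)

# Same sequence as the library source: clone, then in-place subtract and divide
y = x.clone()
y.sub_(mean[:, None, None]).div_(std[:, None, None])

print(y[0])  # 0.0 -> -1.0, 0.5 -> 0.0, 1.0 -> 1.0
print(x[0])  # the original is unchanged because of the clone
```

Because the in-place operations run on the clone, the caller's tensor is only modified when inplace=True is passed.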
Calling the two together: Compose
  • It usually looks like this

    • transform = transforms.Compose(
              [transforms.ToTensor(),
               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
          )
      
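  • You pass in a list of transforms, and __call__ simply iterates over them, feeding the output of each one into the next. Here is its source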

    • class Compose(object):
          """Composes several transforms together.
      
          Args:
              transforms (list of ``Transform`` objects): list of transforms to compose.
      
          Example:
              >>> transforms.Compose([
              >>>     transforms.CenterCrop(10),
              >>>     transforms.ToTensor(),
              >>> ])
          """
      
          def __init__(self, transforms):
              self.transforms = transforms
      
          def __call__(self, img):
              for t in self.transforms:
                  img = t(img)
              return img
      
          def __repr__(self):
              format_string = self.__class__.__name__ + '('
              for t in self.transforms:
                  format_string += '\n'
                  format_string += '    {0}'.format(t)
              format_string += '\n)'
              return format_string
      