PyTorch源码分析之torchvision.transforms

PyTorch框架中有一个非常重要且好用的包:torchvision,该包主要由3个子包组成,分别是:torchvision.datasetstorchvision.modelstorchvision.transforms。这3个子包的具体介绍可以参考官网:http://pytorch.org/docs/master/torchvision/index.html。具体代码可以参考github:https://github.com/pytorch/vision/tree/master/torchvision。

这篇博客介绍torchvision.transformas。torchvision.transforms这个包中包含resize、crop等常见的data augmentation操作,基本上PyTorch中的data augmentation操作都可以通过该接口实现。该包主要包含两个脚本:transformas.py和functional.py,前者定义了各种data augmentation的类,在每个类中通过调用functional.py中对应的函数完成data augmentation操作。

使用例子:

import torchvision
import torch
train_augmentation = torchvision.transforms.Compose([torchvision.transforms.Resize(256),
                                                    torchvision.transforms.RandomCrop(224),                                                                            
                                                    torchvision.transofrms.RandomHorizontalFlip(),
                                                    torchvision.transforms.ToTensor(),
                                                    torch vision.Normalize([0.485, 0.456, -.406],[0.229, 0.224, 0.225])
                                                    ])

Class custom_dataread(torch.utils.data.Dataset):
    def __init__():
        ...
    def __getitem__():
        # use self.transform for input image
    def __len__():
        ...

train_loader = torch.utils.data.DataLoader(
    custom_dataread(transform=train_augmentation),
    batch_size = batch_size, shuffle = True,
    num_workers = workers, pin_memory = True)

这里定义了resize、crop、normalize等数据预处理操作,并最终作为数据读取类custom_dataread的一个参数传入,可以在内部方法__getitem__中实现数据增强操作。

主要代码在transformas.py脚本中,这里仅介绍常见的data augmentation操作,源码如下:
首先是导入必须的模型,这里比较重要的是from . import functional as F,也就是导入了functional.py脚本中具体的data augmentation函数。__all__列表定义了可以从外部import的函数名或类名。

from __future__ import division
import torch
import math
import random
from PIL import Image, ImageOps, ImageEnhance
try:
    import accimage
except ImportError:
    accimage = None
import numpy as np
import numbers
import types
import collections
import warnings

from . import functional as F

__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize",
"Scale", "CenterCrop", "Pad", "Lambda", "RandomCrop", 
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", 
"RandomSizedCrop", "FiveCrop", "TenCrop","LinearTransformation", 
"ColorJitter", "RandomRotation", "Grayscale", "RandomGrayscale"]

Compose这个类是用来管理各个transform的,可以看到主要的__call__方法就是对输入图像img循环所有的transform操作

class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string

ToTensor类是实现:Convert a PIL Image or numpy.ndarray to tensor 的过程,在PyTorch中常用PIL库来读取图像数据,因此这个方法相当于搭建了PIL Image和Tensor的桥梁。另外要强调的是在做数据归一化之前必须要把PIL Image转成Tensor,而其他resize或crop操作则不需要。

class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.to_tensor(pic)

    def __repr__(self):
        return self.__class__.__name__ + '()'

ToPILImage顾名思义是从Tensor到PIL Image的过程,和前面ToTensor类的相反的操作。

class ToPILImage(object):
    """Convert a tensor or an ndarray to PIL Image.

    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
    H x W x C to a PIL Image while preserving the value range.

    Args:
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
            If ``mode`` is ``None`` (default) there are some assumptions made about the input data:
            1. If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
            2. If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
            3. If the input has 1 channel, the ``mode`` is determined by the data type (i,e,
            ``int``, ``float``, ``short``).

    .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes
    """
    def __init__(self, mode=None):
        self.mode = mode

    def __call__(self, pic):
        """
        Args:
            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.

        Returns:
            PIL Image: Image converted to PIL Image.

        """
        return F.to_pil_image(pic, self.mode)

    def __repr__(self):
        return self.__class__.__name__ + '({0})'.format(self.mode)

Normalize类是做数据归一化的,一般都会对输入数据做这样的操作,公式也在注释中给出了,比较容易理解。前面提到在调用Normalize的时候,输入得是Tensor,这个从__call__方法的输入也可以看出来了。

class Normalize(object):
    """Normalize an tensor image with mean and standard deviation.
    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``torch.*Tensor`` i.e.
    ``input[channel] = (input[channel] - mean[channel]
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值