Using transforms in torchvision

Contents

1. Structure and usage of transforms

1.1 The transforms.ToTensor class

1.2 The transforms.Normalize class

1.3 The transforms.Resize class

1.4 The transforms.Compose class

2. Using dataset together with transforms


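transforms is mainly used to apply transformations to images.

1. Structure and usage of transforms

In torchvision, transforms is essentially a transforms.py file that defines one class per kind of transformation.

  (Image credit: Bilibili uploader 我是土堆)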

1.1 The transforms.ToTensor class

        Input: an image --> output: the same image as a tensor:

class ToTensor:
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8

    In the other cases, tensors are returned without scaling.

    .. note::
        Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
        transforming target image masks. See the `references`_ for implementing the transforms for image masks.

    .. _r

We use transforms.ToTensor to answer two questions:

1) How: how is it used?

from torchvision import transforms
from PIL import Image

img_path = r"C:\Users\L\Desktop\deeplearning\BelgiumTSC\Training\00000\01153_00000.ppm"
img = Image.open(img_path)
print(img)
tensor_trans = transforms.ToTensor()  # instantiate the transform
tensor_img = tensor_trans(img)
print(tensor_img)

Result:

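2) Why: why convert to the tensor data type?

   The tensor data type wraps the attributes that neural-network training relies on, such as the gradient (grad), requires_grad, and device.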
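
To make this concrete, the following sketch (using a made-up three-element tensor rather than an image) prints a few of those attributes and lets autograd fill in the gradient.

```python
import torch

# A tensor carries training-related attributes alongside its data.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
print(x.device, x.dtype, x.requires_grad)

# Autograd records the operations so that gradients can be computed afterwards.
y = (x ** 2).sum()
y.backward()
print(x.grad)  # derivative of sum(x^2) with respect to x, i.e. 2 * x
```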

1.2 The transforms.Normalize class

        Normalizes a tensor image channel by channel, using a given mean and standard deviation.

        Formula: output[channel] = (input[channel] - mean[channel]) / std[channel]

class Normalize(torch.nn.Module):
    """Normalize a tensor image with mean and standard deviation.
    This transform does not support PIL Image.
    Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
    channels, this transform will normalize each channel of the input
    ``torch.*Tensor`` i.e.,
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    .. note::
        This transform acts out of place, i.e., it does not mutate the input tensor.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace(bool,optional): Bool to make this operation in-place.

    """

Hands-on example:

from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

image_path = r"C:\Users\L\Desktop\deeplearning\BelgiumTSC\Training\00000\01153_00000.ppm"
writer = SummaryWriter("logs")
img = Image.open(image_path)
print(img)

# ToTensor
trans_tensor = transforms.ToTensor()
img_tensor = trans_tensor(img)
writer.add_image("ToTensor", img_tensor)

# Normalize: output[channel] = (input[channel] - mean[channel]) / std[channel]
print(img_tensor[0][0][0])
trans_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
img_norm = trans_norm(img_tensor)
print(img_norm[0][0][0])
writer.add_image("Normalize", img_norm)

writer.close()

Result:

1.3 The transforms.Resize class

        Resizes an image.

class Resize(torch.nn.Module):
    """Resize the input image to the given size.
    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

    .. warning::
        The output image might be different depending on its type: when downsampling, the interpolation of PIL images
        and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences
        in the performance of a network. Therefore, it is preferable to train and serve a model with the same input
        types. See also below the ``antialias`` parameter, which can help making the output of PIL images and tensors
        closer.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size).

            .. note::
                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and
            ``InterpolationMode.BICUBIC`` are supported.
            For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted,
            but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum.
        max_size (int, optional): The maximum allowed for the longer edge of
            the resized image: if the longer edge of the image is greater
            than ``max_size`` after being resized according to ``size``, then
            the image is resized again so that the longer edge is equal to
            ``max_size``. As a result, ``size`` might be overruled, i.e the
            smaller edge may be shorter than ``size``. This is only supported
            if ``size`` is an int (or a sequence of length 1 in torchscript
            mode).
        antialias (bool, optional): antialias flag. If ``img`` is PIL Image, the flag is ignored and anti-alias
            is always used. If ``img`` is Tensor, the flag is False by default and can be set to True for
            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` modes.
            This can help making the output for PIL images and tensors closer.
    """

Hands-on example:

from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

image_path = r"C:\Users\L\Desktop\deeplearning\BelgiumTSC\Training\00000\01153_00000.ppm"
img = Image.open(image_path)

# Resize
print("original image: ", img.size)
trans_resize = transforms.Resize((512, 512))
img_resize = trans_resize(img)
print("resized image: ", img_resize.size)

Result:

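1.4 The transforms.Compose class

Note: the argument to transforms.Compose() must be a list, i.e. [item1, item2, ...], and every item in the list must be a transform, i.e. Compose([transform1, transform2, ...]). The transforms are applied in order, so the output of each one must be a valid input for the next.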

class Compose:
    """Composes several transforms together. This transform does not support torchscript.
    Please, see the note below.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>> ])

    .. note::
        In order to script the transformations, please use ``torch.nn.Sequential`` as below.

        >>> transforms = torch.nn.Sequential(
        >>>     transforms.CenterCrop(10),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>> )
        >>> scripted_transforms = torch.jit.script(transforms)

        Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
        `lambda` functions or ``PIL.Image``.

    """

Hands-on example:

from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

image_path = r"C:\Users\L\Desktop\deeplearning\BelgiumTSC\Training\00000\01153_00000.ppm"
img = Image.open(image_path)

# ToTensor
trans_tensor = transforms.ToTensor()
img_tensor = trans_tensor(img)

# Compose - resize
trans_resize_2 = transforms.Resize(512)
trans_compose = transforms.Compose([trans_resize_2, trans_tensor])
img_compose = trans_compose(img)
print("input image: ", img.size)
print("composed image: ", img_compose.shape)

Result:

2. Using dataset together with transforms

The CIFAR10 dataset:

 

import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
# Datasets (downloaded to ./dataset on first use)
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)

writer = SummaryWriter("CIFAR10_logs")
for i in range(10):
    img, target = test_set[i]  # img is already a tensor because of the transform
    writer.add_image("test_set", img, i)

writer.close()  # flush pending events to disk

Result:

 

 
