Visualizing what PyTorch transforms actually do

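torchvision, PyTorch's image-processing library, provides transforms that bundle random flipping, rotation, contrast enhancement, conversion to a tensor, conversion back to an image, and more, which makes data augmentation very convenient. Beginners may still wonder what an image actually looks like after passing through these transforms, so this post uses matplotlib to plot the results and build intuition.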

Environment

  • Jupyter Notebook
  • Python 3
  • matplotlib
  • PyTorch (version 1.2.0 used here)
  • Pillow (version 7.1.2 used here)
  • torchvision (version 0.4.2 used here)

Official documentation

torchvision official documentation: the torchvision.transforms reference

Commonly used augmentation methods:

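  • Color Jittering: color augmentation that randomly changes the image's brightness, saturation, and contrast (torchvision's ColorJitter can also change hue);
  • PCA Jittering: compute the mean and standard deviation of each RGB channel, then compute the covariance matrix of the RGB values over the whole training set and eigen-decompose it; the resulting eigenvectors and eigenvalues are used to add a random color offset to each image;
  • Random Scale: rescaling;
  • Random Crop: crop a random region and resize it using interpolation; this includes Scale Jittering (used by VGG and ResNet) and scale and aspect-ratio augmentation;
  • Horizontal/Vertical Flip: mirroring horizontally/vertically;
  • Shift: translation;
  • Rotation/Reflection: rotation and reflection (both special cases of affine transforms);
  • Noise: Gaussian noise and blurring.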
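Of the methods above, PCA Jittering is the least obvious, and as far as I know torchvision 0.4.2 has no ready-made transform for it. Below is a minimal numpy sketch of the AlexNet-style version, assuming images are float arrays of shape (N, H, W, 3) with values in [0, 1]. The function names fit_pca_lighting and pca_jitter and the default alphastd=0.1 are illustrative choices, not a library API. np.cov subtracts the per-channel mean itself; dividing by the per-channel std first, as the list item describes, is left out here.

```python
import numpy as np

def fit_pca_lighting(images):
    """Eigen-decompose the 3x3 RGB covariance of a set of (N, H, W, 3) images in [0, 1]."""
    pixels = images.reshape(-1, 3)               # one row per pixel: (R, G, B)
    cov = np.cov(pixels, rowvar=False)           # 3x3 covariance between the colour channels
    eigval, eigvec = np.linalg.eigh(cov)         # columns of eigvec are the eigenvectors
    return eigval, eigvec

def pca_jitter(img, eigval, eigvec, alphastd=0.1, rng=None):
    """Add one random RGB offset, built from the principal components, to an (H, W, 3) image."""
    rng = np.random.default_rng() if rng is None else rng
    alpha = rng.normal(0.0, alphastd, size=3)    # a random weight per principal component
    shift = eigvec @ (alpha * eigval)            # sum_i alpha_i * lambda_i * p_i, shape (3,)
    return np.clip(img + shift, 0.0, 1.0)        # the same offset is added to every pixel

# Usage with a hypothetical random "training set" of 4 tiny images
rng = np.random.default_rng(0)
train = rng.random((4, 8, 8, 3))
eigval, eigvec = fit_pca_lighting(train)
augmented = pca_jitter(train[0], eigval, eigvec, rng=rng)
print(augmented.shape)  # (8, 8, 3)
```

The covariance is fitted once on the training set; only alpha is redrawn for each image, so the color shift follows the directions in which the dataset's colors naturally vary.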

Visualization

Load the required libraries

import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.datasets.folder import default_loader
# Display figures inline in Jupyter Notebook
%matplotlib inline 

Load the image

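# Replace with the path to your own image; forward slashes also work on Windows
img_path = 'D:/12-GitHub/datasets/flickr8k/images/667626_18933d713e.jpg'
x = default_loader(img_path)  # with the default PIL backend, returns a PIL image converted to RGB
print(x.size)  # PIL reports (width, height)
x  # the last expression in a cell is displayed by Jupyter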

[Figure: the original image]
The image is 500 pixels wide and 375 pixels tall.

Defining transforms and viewing the resulting images

Basic version

First, try the simplest one, transforms.ToTensor():

transform = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> Tensor of shape (C, H, W), values scaled from [0, 255] to [0, 1]
])

img = transform(x)
print(img.size())
img

Out:

torch.Size([3, 375, 500])
tensor([[[0.9529, 0.9490, 0.9020,  ..., 0.4941, 0.4667, 0.5922],
         [0.8588, 0.8000, 0.9569,  ..., 0.6627, 0.6078, 0.4275],
         [0.9843, 0.8275, 0.7961,  ..., 0.5725, 0.5922, 0.5529],
         ...,
         [0.9059, 0.9176, 0.9412,  ..., 0.4980, 0.6980, 0.6235],
         [0.9216, 0.8902, 0.9373,  ..., 0.8078, 0.8549, 0.7686],
         [0.9765, 0.9373, 0.9294,  ..., 0.9216, 0.8275, 0.9961]],

        [[0.9569, 0.9569, 0.9176,  ..., 0.6314, 0.6000, 0.7216],
         [0.8667, 0.8196, 0.9804,  ..., 0.8000, 0.7412, 0.5569],
         [1.0000, 0.8627, 0.8353,  ..., 0.7176, 0.7333, 0.6941],
         ...,
         [0.9137, 0.9020, 0.8941,  ..., 0.5569, 0.7569, 0.6824],
         [0.8863, 0.8549, 0.8941,  ..., 0.8431, 0.8902, 0.8039],
         [0.8941, 0.8706, 0.8941,  ..., 0.9255, 0.8314, 1.0000]],

        [[1.0000, 1.0000, 0.9529,  ..., 0.6039, 0.6039, 0.7490],
         [0.9137, 0.8353, 0.9804,  ..., 0.7725, 0.7373, 0.5765],
         [1.0000, 0.8510, 0.8000,  ..., 0.6902, 0.7176, 0.6863],
         ...,
         [0.9098, 0.9059, 0.9098,  ..., 0.5686, 0.7686, 0.6941],
         [0.8980, 0.8667, 0.9098,  ..., 0.8627, 0.9098, 0.8235],
         [0.9216, 0.8941, 0.9059,  ..., 0.9451, 0.8510, 1.0000]]])

The PIL image of size (500, 375), i.e. width × height, has become a tensor of shape (3, 375, 500), i.e. channels × height × width. Next, plot it:

plt.figure(figsize=(8,8))
img = transform(x)
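img = img.numpy().transpose((1,2,0))  # (C, H, W) -> (H, W, C), the layout imshow expects
img = np.clip(img,0,1)  # keep values in the valid float range for imshow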
plt.imshow(img)

[Figure: the image after ToTensor()]
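To see where the shape change comes from, here is a rough numpy equivalent of ToTensor() for an 8-bit RGB image, plus the reverse step used in the plotting cell. It is a sketch that assumes the input is already an (H, W, 3) uint8 array, which is what np.asarray returns for an RGB PIL image (PIL's .size reports (width, height), while the array shape is (height, width, channels)). The helper names to_tensor_like and to_displayable are hypothetical.

```python
import numpy as np

def to_tensor_like(hwc_uint8):
    """Roughly what ToTensor() does to an 8-bit RGB image: HWC uint8 -> CHW float in [0, 1]."""
    chw = hwc_uint8.transpose((2, 0, 1))           # channels first: (3, H, W)
    return chw.astype(np.float32) / 255.0          # 0..255 -> 0.0..1.0

def to_displayable(chw):
    """The reverse layout change used before plt.imshow: CHW -> HWC, clipped to [0, 1]."""
    return np.clip(chw.transpose((1, 2, 0)), 0.0, 1.0)

# A hypothetical solid-orange image, 375 rows (height) by 500 columns (width)
img = np.zeros((375, 500, 3), dtype=np.uint8)
img[...] = (255, 128, 0)
t = to_tensor_like(img)
print(t.shape)                      # (3, 375, 500)
print(to_displayable(t).shape)      # (375, 500, 3)
```

The transpose only reorders axes; the actual pixel values change solely through the division by 255.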

Intermediate version: normalization

Next, add transforms.Normalize():

transform = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> Tensor (C, H, W), values scaled to [0, 1]
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # per-channel (x - mean) / std
])

img = transform(x)
print(img.size())
img

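The Normalize values are empirical statistics computed on ImageNet: [0.485, 0.456, 0.406] is the per-channel mean and [0.229, 0.224, 0.225] is the per-channel standard deviation (std), not the variance.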

Out:

torch.Size([3, 375, 500])
tensor([[[ 2.0434,  2.0263,  1.8208,  ...,  0.0398, -0.0801,  0.4679],
         [ 1.6324,  1.3755,  2.0605,  ...,  0.7762,  0.5364, -0.2513],
         [ 2.1804,  1.4954,  1.3584,  ...,  0.3823,  0.4679,  0.2967],
         ...,
         [ 1.8379,  1.8893,  1.9920,  ...,  0.0569,  0.9303,  0.6049],
         [ 1.9064,  1.7694,  1.9749,  ...,  1.4098,  1.6153,  1.2385],
         [ 2.1462,  1.9749,  1.9407,  ...,  1.9064,  1.4954,  2.2318]],

        [[ 2.2360,  2.2360,  2.0609,  ...,  0.7829,  0.6429,  1.1856],
         [ 1.8333,  1.6232,  2.3410,  ...,  1.5357,  1.2731,  0.4503],
         [ 2.4286,  1.8158,  1.6933,  ...,  1.1681,  1.2381,  1.0630],
         ...,
         [ 2.0434,  1.9909,  1.9559,  ...,  0.4503,  1.3431,  1.0105],
         [ 1.9209,  1.7808,  1.9559,  ...,  1.7283,  1.9384,  1.5532],
         [ 1.9559,  1.8508,  1.9559,  ...,  2.0959,  1.6758,  2.4286]],

        [[ 2.6400,  2.6400,  2.4308,  ...,  0.8797,  0.8797,  1.5245],
         [ 2.2566,  1.9080,  2.5529,  ...,  1.6291,  1.4722,  0.7576],
         [ 2.6400,  1.9777,  1.7511,  ...,  1.2631,  1.3851,  1.2457],
         ...,
         [ 2.2391,  2.2217,  2.2391,  ...,  0.7228,  1.6117,  1.2805],
         [ 2.1868,  2.0474,  2.2391,  ...,  2.0300,  2.2391,  1.8557],
         [ 2.2914,  2.1694,  2.2217,  ...,  2.3960,  1.9777,  2.6400]]])
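Normalize applies x_norm = (x - mean) / std to each channel separately. As a quick numpy check against the first printed value of each channel: assuming the top-left pixel was (243, 244, 255) before ToTensor (which gives the 0.9529, 0.9569, 1.0000 printed in the basic version), the formula reproduces the 2.0434, 2.2360 and 2.6400 printed above. The function name normalize is illustrative.

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)  # per-channel mean, broadcast over H and W
STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)   # per-channel standard deviation

def normalize(chw):
    """The arithmetic behind transforms.Normalize(mean, std) for a (3, H, W) array."""
    return (chw - MEAN) / STD

# Assumed top-left pixel before ToTensor, as a (3, 1, 1) "image"
pixel = np.array([243.0, 244.0, 255.0]).reshape(3, 1, 1) / 255.0
print(np.round(normalize(pixel).ravel(), 4))  # approximately 2.0434, 2.2360, 2.6400
```

Because the image is bright, most normalized values end up well above 1, which matters for the plot that follows.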

The values are different from before. Let's plot them:

plt.figure(figsize=(8,8))
img = transform(x)
img = img.numpy().transpose((1,2,0))
# No clipping here on purpose: imshow clips float RGB data to [0, 1] itself (and prints a warning)
plt.imshow(img)

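[Figure: the normalized tensor plotted directly]
It looks like an overexposed photo: imshow clips values above 1 to 1 and negative values to 0. To get the normal image back, undo the normalization (x = std * x_norm + mean):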

plt.figure(figsize=(8,8))
img = transform(x)
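img = img.numpy().transpose((1,2,0))  # (H, W, C), so mean and std broadcast over the last axis
mean = np.array([0.485,0.456,0.406])
std = np.array([0.229,0.224,0.225])
img = std*img+mean  # invert Normalize: x = std * x_norm + mean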
img = np.clip(img,0,1)
plt.imshow(img)

[Figure: the image restored after undoing Normalize]
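The inversion in the cell above (img = std*img + mean) is itself a normalization with different parameters: (y - (-mean/std)) / (1/std) = std*y + mean. So, if you prefer to stay inside a transforms pipeline, transforms.Normalize with mean -mean/std and std 1/std should undo the original Normalize. The numpy sketch below only checks that algebra; it does not call torchvision, and the helper normalize is illustrative.

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

def normalize(chw, m, s):
    """(x - m) / s per channel for a (3, H, W) array, like transforms.Normalize(m, s)."""
    return (chw - m.reshape(3, 1, 1)) / s.reshape(3, 1, 1)

inv_mean, inv_std = -mean / std, 1.0 / std   # parameters of the inverse normalization

rng = np.random.default_rng(0)
x = rng.random((3, 4, 5))                    # hypothetical image in [0, 1]
y = normalize(x, mean, std)                  # forward, as in the pipeline
x_back = normalize(y, inv_mean, inv_std)     # inverse, expressed as an ordinary normalization
print(np.allclose(x, x_back))                # True
```

Working on the (C, H, W) tensor like this avoids the transpose-then-broadcast step, but the result still has to be transposed to (H, W, C) before plt.imshow.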

Intermediate version: flipping and rotation

Next, add random horizontal flipping with transforms.RandomHorizontalFlip() and random rotation with transforms.RandomRotation().

transform = transforms.Compose([
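    # transforms.ToPILImage(),  # uncomment if the input is a numpy array
    transforms.RandomHorizontalFlip(),  # flip left-right with probability 0.5 (the default p)
    transforms.RandomRotation(15),  # rotate by a random angle in [-15, 15] degrees
    transforms.ToTensor(),  # PIL image -> Tensor (C, H, W), values scaled to [0, 1]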
])

img = transform(x)
print(img.size())
img

Out:

torch.Size([3, 375, 500])
tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])

Why are so many values 0? Let's visualize it:

# Re-run this cell to draw a new random sample each time
plt.figure(figsize=(8,8))
img = transform(x)
img=img.numpy().transpose((1,2,0))
img = np.clip(img,0,1)
plt.imshow(img)

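[Figure: one random flip/rotation sample]
Because of the random rotation, the corners that have no corresponding source pixel are filled with black (0), which is why the printed corner values are all 0. Let's sample a few more times (just re-run the cell above, which calls img = transform(x); there is no need to re-run the transform definition):
[Figures: three more random flip/rotation samples]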
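What the two random transforms do can be sketched in numpy. This is an approximation, not torchvision's implementation: RandomHorizontalFlip() mirrors the image with probability 0.5 (its default p), and RandomRotation(15) draws an angle uniformly from [-15, 15] degrees and rotates about the centre without enlarging the canvas. The sketch rotates with a simple nearest-neighbour inverse mapping (torchvision delegates to PIL, and the sign convention of the angle is not matched here), which is enough to show why the corners come out as 0: those output pixels have no source pixel inside the original image, so they keep the fill value. All function names are hypothetical.

```python
import numpy as np

def random_hflip(img, p=0.5, rng=None):
    """Mirror an (H, W, C) image left-right with probability p, like RandomHorizontalFlip."""
    rng = np.random.default_rng() if rng is None else rng
    return img[:, ::-1] if rng.random() < p else img

def rotate_nearest(img, angle_deg, fill=0.0):
    """Rotate an (H, W, C) image about its centre (nearest neighbour, same canvas size).
    The direction convention is not matched to torchvision/PIL."""
    h, w = img.shape[:2]
    theta = np.deg2rad(angle_deg)
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    ys, xs = np.mgrid[0:h, 0:w]                              # output pixel grid
    # Inverse mapping: for each output pixel, find the source pixel it comes from
    src_x = np.cos(theta) * (xs - cx) + np.sin(theta) * (ys - cy) + cx
    src_y = -np.sin(theta) * (xs - cx) + np.cos(theta) * (ys - cy) + cy
    src_x, src_y = np.rint(src_x).astype(int), np.rint(src_y).astype(int)
    inside = (src_x >= 0) & (src_x < w) & (src_y >= 0) & (src_y < h)
    out = np.full_like(img, fill)                            # pixels without a source keep `fill`
    out[inside] = img[src_y[inside], src_x[inside]]
    return out

def random_rotation(img, degrees=15, rng=None):
    """Rotate by an angle drawn uniformly from [-degrees, degrees]."""
    rng = np.random.default_rng() if rng is None else rng
    return rotate_nearest(img, rng.uniform(-degrees, degrees))

img = np.ones((50, 80, 3))                                   # hypothetical all-white image
rotated = rotate_nearest(img, 15)
print(rotated[0, 0], rotated[25, 40])                        # [0. 0. 0.] [1. 1. 1.]
```

The corner pixel turns black while the centre stays white, the same pattern as the zeros in the printed tensor.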

Intermediate version: adjusting brightness, contrast, saturation, hue, and grayscale

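brightness, contrast, saturation, and hue control random changes to the image's brightness, contrast, saturation, and hue respectively. For example, brightness=0.1 scales the brightness by a factor drawn uniformly from [0.9, 1.1] (contrast and saturation work the same way), and hue=0.1 shifts the hue by a random amount in [-0.1, 0.1]. RandomGrayscale() converts the image to grayscale with probability 0.1 by default, keeping 3 channels for an RGB input.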

transform = transforms.Compose([
    # transforms.ToPILImage(),  # uncomment if the input is a numpy array
#     transforms.RandomHorizontalFlip(),  # random horizontal flip
#     transforms.RandomRotation(15),  # random rotation
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.RandomGrayscale(),  # grayscale with probability 0.1 (the default p)
    transforms.ToTensor(),  # PIL image -> Tensor (C, H, W), values scaled to [0, 1]
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

img = transform(x)
print(img.size())
img

Out:

torch.Size([3, 375, 500])
tensor([[[0.9020, 0.8980, 0.8588,  ..., 0.4314, 0.4000, 0.5333],
         [0.8196, 0.7333, 0.8902,  ..., 0.5961, 0.5412, 0.3608],
         [0.9176, 0.7569, 0.7333,  ..., 0.5059, 0.5294, 0.4902],
         ...,
         [0.8392, 0.8510, 0.8745,  ..., 0.4353, 0.6314, 0.5569],
         [0.8549, 0.8235, 0.8706,  ..., 0.7412, 0.7882, 0.7020],
         [0.9098, 0.8706, 0.8627,  ..., 0.8667, 0.7725, 0.9176]],

        [[0.8902, 0.8863, 0.8392,  ..., 0.5255, 0.4627, 0.5765],
         [0.7922, 0.7412, 0.9020,  ..., 0.6863, 0.6118, 0.4118],
         [0.9176, 0.7922, 0.7686,  ..., 0.6039, 0.6118, 0.5608],
         ...,
         [0.8510, 0.8431, 0.8314,  ..., 0.4588, 0.6549, 0.5804],
         [0.8196, 0.7961, 0.8275,  ..., 0.7451, 0.7922, 0.7098],
         [0.8392, 0.8196, 0.8275,  ..., 0.8549, 0.7608, 0.9176]],

        [[0.9176, 0.9176, 0.8902,  ..., 0.5686, 0.5412, 0.6863],
         [0.8510, 0.7686, 0.9137,  ..., 0.7373, 0.6784, 0.5137],
         [0.9176, 0.8000, 0.7569,  ..., 0.6549, 0.6706, 0.6314],
         ...,
         [0.8510, 0.8353, 0.8275,  ..., 0.5059, 0.7059, 0.6314],
         [0.8157, 0.7922, 0.8275,  ..., 0.8000, 0.8471, 0.7608],
         [0.8235, 0.8039, 0.8235,  ..., 0.8784, 0.7882, 0.9176]]])

Plot:

plt.figure(figsize=(8,8))
img = transform(x)
img=img.numpy().transpose((1,2,0))
# img = np.clip(img,0,1)
plt.imshow(img)

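[Figure: one ColorJitter/RandomGrayscale sample]

Again, let's draw a few more random samples:
[Figures: two more ColorJitter/RandomGrayscale samples]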
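A rough numpy sketch of what ColorJitter and RandomGrayscale do to an (H, W, 3) float image in [0, 1], assuming they behave like the PIL ImageEnhance blends that torchvision 0.4.x uses for PIL inputs: brightness blends with a black image, contrast with a flat gray image at the mean luminance, and saturation with the grayscale version, each with a factor drawn from [1 - v, 1 + v] and applied in random order. Hue is left out because it needs an HSV conversion. The helper names are hypothetical and the results are not bit-exact.

```python
import numpy as np

LUMA = np.array([0.299, 0.587, 0.114])  # ITU-R 601-2 luma weights, as in PIL's "L" conversion

def grayscale(img):
    """(H, W, 3) -> (H, W, 3) with all three channels set to the luma."""
    return np.repeat((img @ LUMA)[..., None], 3, axis=2)

def blend(img, degenerate, factor):
    """factor=1 keeps img, factor=0 gives degenerate, factor>1 moves further away from it."""
    return np.clip(degenerate + factor * (img - degenerate), 0.0, 1.0)

def color_jitter(img, brightness=0.1, contrast=0.1, saturation=0.1, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    ops = [
        lambda x, f: blend(x, np.zeros_like(x), f),                    # brightness vs. black
        lambda x, f: blend(x, np.full_like(x, (x @ LUMA).mean()), f),  # contrast vs. mean gray
        lambda x, f: blend(x, grayscale(x), f),                        # saturation vs. grayscale
    ]
    strengths = [brightness, contrast, saturation]
    for i in rng.permutation(3):                                       # random order on each call
        v = strengths[i]
        factor = rng.uniform(max(0.0, 1.0 - v), 1.0 + v)               # 0.1 -> factor in [0.9, 1.1]
        img = ops[i](img, factor)
    return img

def random_grayscale(img, p=0.1, rng=None):
    """Grayscale with probability p, still 3 channels, like RandomGrayscale(p=0.1)."""
    rng = np.random.default_rng() if rng is None else rng
    return grayscale(img) if rng.random() < p else img

rng = np.random.default_rng(0)
img = rng.random((4, 5, 3))                                  # hypothetical tiny image
out = random_grayscale(color_jitter(img, rng=rng), rng=rng)
print(out.shape)                                             # (4, 5, 3)
```

With strengths of 0.1 every factor stays within 10% of 1, which is why the samples above differ only subtly from the original.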

Advanced version

Despite the name, this simply combines everything above:

transform = transforms.Compose([
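    # transforms.ToPILImage(),  # uncomment if the input is a numpy array
    transforms.RandomHorizontalFlip(),  # flip left-right with probability 0.5
    transforms.RandomRotation(15),  # rotate by a random angle in [-15, 15] degrees
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.RandomGrayscale(),  # grayscale with probability 0.1
    transforms.ToTensor(),  # PIL image -> Tensor (C, H, W), values scaled to [0, 1]
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # per-channel (x - mean) / std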
])

img = transform(x)
print(img.size())
img

Out:

torch.Size([3, 375, 500])
tensor([[[-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
         [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
         [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
         ...,
         [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
         [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
         [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179]],

        [[-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
         [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
         [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
         ...,
         [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
         [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
         [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357]],

        [[-1.8044, -1.8044, -1.8044,  ..., -1.8044, -1.8044, -1.8044],
         [-1.8044, -1.8044, -1.8044,  ..., -1.8044, -1.8044, -1.8044],
         [-1.8044, -1.8044, -1.8044,  ..., -1.8044, -1.8044, -1.8044],
         ...,
         [-1.8044, -1.8044, -1.8044,  ..., -1.8044, -1.8044, -1.8044],
         [-1.8044, -1.8044, -1.8044,  ..., -1.8044, -1.8044, -1.8044],
         [-1.8044, -1.8044, -1.8044,  ..., -1.8044, -1.8044, -1.8044]]])
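The constant values in this output are not random. In this run the printed corner pixels are the black (0) fill left by RandomRotation, and Normalize maps a value of 0 to -mean/std in each channel. A quick numpy check with the same mean and std:

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
black_normalized = (0.0 - mean) / std      # a pixel that is 0 in every channel after ToTensor
print(np.round(black_normalized, 4))       # approximately -2.1179, -2.0357, -1.8044
```

These negative values are also why the corners stay pure black when the normalized tensor is plotted directly: imshow clips them to 0.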

Plot it:

plt.figure(figsize=(8,8))
img = transform(x)
img = img.numpy().transpose((1,2,0))
# Normalize is not undone here, so imshow clips the values to [0, 1]
# img = np.clip(img,0,1)
plt.imshow(img)

[Figure: result of the combined transform, plotted without undoing Normalize]

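Some readers may ask whether images processed like this, which don't look great, are really fit for training. Perhaps only the model knows: the picture is for human eyes, while inside the model an image is just an array of numbers. Note also that part of the odd look comes from plotting the normalized tensor directly; undoing Normalize, as in the intermediate version, would show a normal-looking augmented image.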

Author: iteapoy
