实战|13个Pytorch 图像增强方法总结(附代码)

点击上方“小白学视觉”,选择加"星标"或“置顶”

重磅干货,第一时间送达

作者丨结发授长生@知乎

链接丨https://zhuanlan.zhihu.com/p/559887437

使用数据增强技术可以增加数据集中图像的多样性,从而提高模型的性能和泛化能力。主要的图像增强技术包括:

  • 调整大小

  • 灰度变换

  • 标准化

  • 随机旋转

  • 中心裁剪

  • 随机裁剪

  • 高斯模糊

  • 亮度、对比度调节

  • 水平翻转

  • 垂直翻转

  • 高斯噪声

  • 随机块

  • 中心区域

调整大小

在开始图像大小的调整之前我们需要导入数据(图像以眼底图像为例)。

from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import sys
import torch
import numpy as np
import torchvision.transforms as T

plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/000001.tif'))
torch.manual_seed(0) # 设置 CPU 生成随机数的 种子 ,方便下次复现实验结果
print(np.asarray(orig_img).shape) #(800, 800, 3)

#图像大小的调整
resized_imgs = [T.Resize(size=size)(orig_img) for size in [128,256]]
# plt.figure('resize:128*128')
ax1 = plt.subplot(131)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(132)
ax2.set_title('resize:128*128')
ax2.imshow(resized_imgs[0])

ax3 = plt.subplot(133)
ax3.set_title('resize:256*256')
ax3.imshow(resized_imgs[1])

plt.show()
37e8ab676fcf42d0e6b279b1fac69cf2.jpeg

灰度变换

此操作将RGB图像转化为灰度图像。

gray_img = T.Grayscale()(orig_img)
# plt.figure('resize:128*128')
ax1 = plt.subplot(121)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(122)
ax2.set_title('gray')
ax2.imshow(gray_img,cmap='gray')
c83fc892f853bb9a01d81c39d8232eb8.jpeg

标准化

标准化可以加快基于神经网络结构的模型的计算速度,加快学习速度。

  • 从每个输入通道中减去通道平均值

  • 将其除以通道标准差。

normalized_img = T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))(T.ToTensor()(orig_img))
normalized_img = [T.ToPILImage()(normalized_img)]
# plt.figure('resize:128*128')
ax1 = plt.subplot(121)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(122)
ax2.set_title('normalize')
ax2.imshow(normalized_img[0])

plt.show()
b6371a927cb20f96a8b19ff3b1348808.jpeg

随机旋转

设计角度旋转图像

from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import sys
import torch
import numpy as np
import torchvision.transforms as T


plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))

rotated_imgs = [T.RandomRotation(degrees=90)(orig_img)]
print(rotated_imgs)
plt.figure('resize:128*128')
ax1 = plt.subplot(121)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(122)
ax2.set_title('90°')
ax2.imshow(np.array(rotated_imgs[0]))
f72b4533c871745166fb5d455fd23d74.jpeg

中心剪切

剪切图像的中心区域

from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import sys
import torch
import numpy as np
import torchvision.transforms as T


plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))

center_crops = [T.CenterCrop(size=size)(orig_img) for size in (128,64)]

plt.figure('resize:128*128')
ax1 = plt.subplot(131)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(132)
ax2.set_title('128*128°')
ax2.imshow(np.array(center_crops[0]))

ax3 = plt.subplot(133)
ax3.set_title('64*64')
ax3.imshow(np.array(center_crops[1]))

plt.show()
04f169db118f199ae168f4b164abf8d3.jpeg

随机裁剪

随机剪切图像的某一部分

from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import sys
import torch
import numpy as np
import torchvision.transforms as T


plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))

random_crops = [T.RandomCrop(size=size)(orig_img) for size in (400,300)]

plt.figure('resize:128*128')
ax1 = plt.subplot(131)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(132)
ax2.set_title('400*400')
ax2.imshow(np.array(random_crops[0]))

ax3 = plt.subplot(133)
ax3.set_title('300*300')
ax3.imshow(np.array(random_crops[1]))

plt.show()
48b31c40682e2adf80fda07b05aec0e5.jpeg

高斯模糊

使用高斯核对图像进行模糊变换

from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import sys
import torch
import numpy as np
import torchvision.transforms as T


plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))

blurred_imgs = [T.GaussianBlur(kernel_size=(3, 3), sigma=sigma)(orig_img) for sigma in (3,7)]

plt.figure('resize:128*128')
ax1 = plt.subplot(131)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(132)
ax2.set_title('sigma=3')
ax2.imshow(np.array(blurred_imgs[0]))

ax3 = plt.subplot(133)
ax3.set_title('sigma=7')
ax3.imshow(np.array(blurred_imgs[1]))

plt.show()
5ce206813b64da015fffa2f4656b8011.jpeg

亮度、对比度和饱和度调节

from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import sys
import torch
import numpy as np
import torchvision.transforms as T


plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))
# random_crops = [T.RandomCrop(size=size)(orig_img) for size in (832,704, 256)]
colorjitter_img = [T.ColorJitter(brightness=(2,2), contrast=(0.5,0.5), saturation=(0.5,0.5))(orig_img)]

plt.figure('resize:128*128')
ax1 = plt.subplot(121)
ax1.set_title('original')
ax1.imshow(orig_img)
ax2 = plt.subplot(122)
ax2.set_title('colorjitter_img')
ax2.imshow(np.array(colorjitter_img[0]))
plt.show()
5f0ab9618c9d5a029883c904f52806f0.jpeg

水平翻转

from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import sys
import torch
import numpy as np
import torchvision.transforms as T


plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))

HorizontalFlip_img = [T.RandomHorizontalFlip(p=1)(orig_img)]

plt.figure('resize:128*128')
ax1 = plt.subplot(121)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(122)
ax2.set_title('colorjitter_img')
ax2.imshow(np.array(HorizontalFlip_img[0]))


plt.show()
cbfabc270efd5173d190992acb85e7ad.jpeg

垂直翻转

from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import sys
import torch
import numpy as np
import torchvision.transforms as T


plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))

VerticalFlip_img = [T.RandomVerticalFlip(p=1)(orig_img)]

plt.figure('resize:128*128')
ax1 = plt.subplot(121)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(122)
ax2.set_title('VerticalFlip')
ax2.imshow(np.array(VerticalFlip_img[0]))

# ax3 = plt.subplot(133)
# ax3.set_title('sigma=7')
# ax3.imshow(np.array(blurred_imgs[1]))

plt.show()
ef4d13a4636f365b2d433d0662e4cf01.jpeg

高斯噪声

向图像中加入高斯噪声。通过设置噪声因子,噪声因子越高,图像的噪声越大。

from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import sys
import torch
import numpy as np
import torchvision.transforms as T


plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))


def add_noise(inputs, noise_factor=0.3):
    noisy = inputs + torch.randn_like(inputs) * noise_factor
    noisy = torch.clip(noisy, 0., 1.)
    return noisy


noise_imgs = [add_noise(T.ToTensor()(orig_img), noise_factor) for noise_factor in (0.3, 0.6)]
noise_imgs = [T.ToPILImage()(noise_img) for noise_img in noise_imgs]

plt.figure('resize:128*128')
ax1 = plt.subplot(131)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(132)
ax2.set_title('noise_factor=0.3')
ax2.imshow(np.array(noise_imgs[0]))

ax3 = plt.subplot(133)
ax3.set_title('noise_factor=0.6')
ax3.imshow(np.array(noise_imgs[1]))

plt.show()
fd747973d042d40b2a422631855971bc.jpeg

随机块

正方形补丁随机应用在图像中。这些补丁的数量越多,神经网络解决问题的难度就越大。

from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import sys
import torch
import numpy as np
import torchvision.transforms as T


plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))


def add_random_boxes(img,n_k,size=64):
    h,w = size,size
    img = np.asarray(img).copy()
    img_size = img.shape[1]
    boxes = []
    for k in range(n_k):
        y,x = np.random.randint(0,img_size-w,(2,))
        img[y:y+h,x:x+w] = 0
        boxes.append((x,y,h,w))
    img = Image.fromarray(img.astype('uint8'), 'RGB')
    return img

blocks_imgs = [add_random_boxes(orig_img,n_k=10)]

plt.figure('resize:128*128')
ax1 = plt.subplot(131)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(132)
ax2.set_title('10 black boxes')
ax2.imshow(np.array(blocks_imgs[0]))


plt.show()
ed7213b78bd01b9adad515f87307497a.jpeg

中心区域

和随机块类似,只不过在图像的中心加入补丁

from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import sys
import torch
import numpy as np
import torchvision.transforms as T


plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))


def add_central_region(img, size=32):
    h, w = size, size
    img = np.asarray(img).copy()
    img_size = img.shape[1]
    img[int(img_size / 2 - h):int(img_size / 2 + h), int(img_size / 2 - w):int(img_size / 2 + w)] = 0
    img = Image.fromarray(img.astype('uint8'), 'RGB')
    return img


central_imgs = [add_central_region(orig_img, size=128)]


plt.figure('resize:128*128')
ax1 = plt.subplot(131)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(132)
ax2.set_title('')
ax2.imshow(np.array(central_imgs[0]))
#
# ax3 = plt.subplot(133)
# ax3.set_title('20 black boxes')
# ax3.imshow(np.array(blocks_imgs[1]))

plt.show()
b7b2d0d19a6919424585ce1127628797.jpeg

本文仅做学术分享,如有侵权,请联系删文。

下载1:OpenCV-Contrib扩展模块中文版教程

在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。


下载2:Python视觉实战项目52讲
在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。


下载3:OpenCV实战项目20讲
在「小白学视觉」公众号后台回复:OpenCV实战项目20讲,即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。


交流群

欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~
  • 6
    点赞
  • 43
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值