深度学习pytorch——图像增广(note)
导入模块
import matplotlib.pyplot as plt
import torchvision
from PIL import Image
读取图片路径
#自己从网上下了一张猫的图片,路径要改变
img = Image.open('D:/###code/Python/TESTpytorch/img/cat.png')
原图
图片展示
def show_images(imgs, num_rows, num_cols, scale=2):
figsize = (num_rows * scale, num_cols * scale)
_, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
for i in range(num_rows):
for j in range(num_cols):
axes[i][j].imshow(imgs[i * num_cols + j])
axes[i][j].axes.get_xaxis().set_visible(False)
axes[i][j].axes.get_yaxis().set_visible(False)
return axes
图像增广的辅助函数apply()
#aug为图像增广的办法
#num_rows=1, num_cols=4指生成2行4列的图片,即对原图作用了8次
def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
Y = [aug(img) for _ in range(num_rows * num_cols)]
show_images(Y, num_rows, num_cols)
对图片进行左右翻转
aug1 = torchvision.transforms.RandomHorizontalFlip()
apply(img, aug1)
随机裁减
#200指高宽都是200像素
#ratio指宽和高之比随机取0.5-2
aug2 = torchvision.transforms.RandomResizedCrop(200, scale=(0.1, 1),
ratio=(0.5, 2))
apply(img, aug2)
变化颜色(亮度, 色度, 对比度, 饱和度)
#bringhtness=0.5 亮度变为原图的50%
aug3 = torchvision.transforms.ColorJitter(brightness=0.5)
apply(img, aug3)
#hue色调
aug4 = torchvision.transforms.ColorJitter(hue=0.5)
apply(img, aug4)
#contrast对比度
aug5 = torchvision.transforms.ColorJitter(contrast=0.5)
apply(img, aug5)
#饱和度
aug6 = torchvision.transforms.ColorJitter(saturation=0.5)
apply(img, aug6)
用一句代码概括
color_aug = torchvision.transforms.ColorJitter(brightness=0.5, hue=0.5,
contrast=0.5, saturation=0.5)
apply(img, color_aug)
叠加多个图像增广
使用torchvision.transforms.Compose()
#aug1 = torchvision.transforms.RandomHorizontalFlip()
#aug2 = torchvision.transforms.RandomResizedCrop(200, scale=(0.1, 1), ratio=(0.5, 2))
augs = torchvision.transforms.Compose([aug1, aug2, color_aug])
apply(img, augs)
图片展示,最后要加上
plt.show()