此示例演示了torchvision.transforms模块中提供的各种变换。
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('assets') / 'astronaut.jpg')
# if you change the seed, make sure that the randomly-applied transforms
# properly show that the image can be both transformed and *not* transformed!
torch.manual_seed(0) #保证每次随机生成的数是相同的
def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
if not isinstance(imgs[0], list): #类型是否相同,不同则对其进行变换
# Make a 2d grid even if there's just 1 row
imgs = [imgs]
num_rows = len(imgs)
num_cols = len(imgs[0]) + with_orig
fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False) # 构建排布
for row_idx, row in enumerate(imgs):
row = [orig_img] + row if with_orig else row
for col_idx, img in enumerate(row):
ax = axs[row_idx, col_idx]
ax.imshow(np.asarray(img), **imshow_kwargs)
ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
if with_orig:
axs[0, 0].set(title='Original image')
axs[0, 0].title.set_size(8)
if row_title is not None:
for row_idx in range(num_rows):
axs[row_idx, 0].set(ylabel=row_title[row_idx])
plt.tight_layout()
Pad
Pad变换使用一些像素值填充图像边界。
padded_imgs = [T.Pad(padding=padding)(orig_img) for padding in (3, 10, 30, 50)]
plot(padded_imgs)
Resize
调整图像大小
resized_imgs = [T.Resize(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)]
plot(resized_imgs)
CenterCrop
在中心裁剪给定图像
center_crops = [T.CenterCrop(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)]
plot(center_crops)
FiveCrop
FiveCrop变换将给定图像裁剪为四个角和中心裁剪
(top_left, top_right, bottom_left, bottom_right, center) = T.FiveCrop(size=(100, 100))(orig_img)
plot([top_left, top_right, bottom_left, bottom_right, center])
Grayscale
灰度变换将图像转换为灰度
gray_img = T.Grayscale()(orig_img)
plot([gray_img], cmap='gray')
Random transforms
以下变换是随机的,这意味着同一Transformer实例每次变换给定图像时将产生不同的结果
ColorJitter
ColorJitter变换随机更改图像的亮度、饱和度和其他属性。
jitter = T.ColorJitter(brightness=.5, hue=.3)
jitted_imgs = [jitter(orig_img) for _ in range(4)]
plot(jitted_imgs)
GaussianBlur
高斯模糊变换对图像执行高斯模糊变换
blurrer = T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))
blurred_imgs = [blurrer(orig_img) for _ in range(4)]
plot(blurred_imgs)
RandomPerspective
随机透视变换对图像执行随机透视变换
perspective_transformer = T.RandomPerspective(distortion_scale=0.6, p=1.0)
perspective_imgs = [perspective_transformer(orig_img) for _ in range(4)]
plot(perspective_imgs)
输出
/home/matti/miniconda3/envs/pytorch-test/lib/python3.8/site-packages/torchvision/transforms/functional.py:594: UserWarning: torch.lstsq is deprecated in favor of torch.linalg.lstsq and will be removed in a future PyTorch release.
torch.linalg.lstsq has reversed arguments and does not return the QR decomposition in the returned tuple (although it returns other information about the problem).
To get the qr decomposition consider using torch.linalg.qr.
The returned solution in torch.lstsq stored the residuals of the solution in the last m - n columns of the returned value whenever m > n. In torch.linalg.lstsq, the residuals in the field ‘residuals’ of the returned named tuple.
The unpacking of the solution, as in
X, _ = torch.lstsq(B, A).solution[:A.size(1)]
should be replaced with
X = torch.linalg.lstsq(A, B).solution (Triggered internally at /opt/conda/conda-bld/pytorch_1623448216815/work/aten/src/ATen/LegacyTHFunctionsCPU.cpp:389.)
res = torch.lstsq(b_matrix, a_matrix)[0]
RandomRotation
随机角度旋转图像
rotater = T.RandomRotation(degrees=(0, 180))
rotated_imgs = [rotater(orig_img) for _ in range(4)]
plot(rotated_imgs)
RandomAffine
随机仿射变换对图像执行随机仿射变换
affine_transfomer = T.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75))
affine_imgs = [affine_transfomer(orig_img) for _ in range(4)]
plot(affine_imgs)
RandomCrop
随机裁剪变换在随机位置裁剪图像
cropper = T.RandomCrop(size=(128, 128))
crops = [cropper(orig_img) for _ in range(4)]
plot(crops)
RandomResizedCrop
RandomResizedCrop变换在随机位置裁剪图像,然后将裁剪大小调整为给定大小
resize_cropper = T.RandomResizedCrop(size=(32, 32))
resized_crops = [resize_cropper(orig_img) for _ in range(4)]
plot(resized_crops)
RandomInvert
随机反转变换随机反转给定图像的颜色。
inverter = T.RandomInvert()
invertered_imgs = [inverter(orig_img) for _ in range(4)]
plot(invertered_imgs)
RandomPosterize
RandomPosterize变换通过减少每个颜色通道的位数来随机对图像进行后验证
posterizer = T.RandomPosterize(bits=2)
posterized_imgs = [posterizer(orig_img) for _ in range(4)]
plot(posterized_imgs)
RandomSolarize
RandomSolarize变换通过反转阈值以上的所有像素值来随机对图像进行solarizes
solarizer = T.RandomSolarize(threshold=192.0)
solarized_imgs = [solarizer(orig_img) for _ in range(4)]
plot(solarized_imgs)
RandomAdjustSharpness
RandomAdjustSharpness变换随机调整给定图像的清晰度
sharpness_adjuster = T.RandomAdjustSharpness(sharpness_factor=2)
sharpened_imgs = [sharpness_adjuster(orig_img) for _ in range(4)]
plot(sharpened_imgs)
RandomAutocontrast
RandomAutocontrast变换随机将自动对比应用于给定图像
autocontraster = T.RandomAutocontrast()
autocontrasted_imgs = [autocontraster(orig_img) for _ in range(4)]
plot(autocontrasted_imgs)
RandomEqualize
随机均衡给定图像的直方图
equalizer = T.RandomEqualize()
equalized_imgs = [equalizer(orig_img) for _ in range(4)]
plot(equalized_imgs)
AutoAugment
自动扩充转换根据给定的自动扩充策略自动扩充数据
policies = [T.AutoAugmentPolicy.CIFAR10, T.AutoAugmentPolicy.IMAGENET, T.AutoAugmentPolicy.SVHN]
augmenters = [T.AutoAugment(policy) for policy in policies]
imgs = [
[augmenter(orig_img) for _ in range(4)]
for augmenter in augmenters
]
row_title = [str(policy).split('.')[-1] for policy in policies]
plot(imgs, row_title=row_title)
Randomly-applied transforms
在给定概率p的情况下,随机应用一些变换。也就是说,即使使用相同的transformer实例调用,转换后的图像实际上可能与原始图像相同
RandomHorizontalFlip
RandomHorizontalFlip变换以给定的概率执行图像的水平翻转
hflipper = T.RandomHorizontalFlip(p=0.5)
transformed_imgs = [hflipper(orig_img) for _ in range(4)]
plot(transformed_imgs)
RandomVerticalFlip
RandomVerticalFlip变换以给定的概率执行图像的垂直翻转
vflipper = T.RandomVerticalFlip(p=0.5)
transformed_imgs = [vflipper(orig_img) for _ in range(4)]
plot(transformed_imgs)
RandomApply
随机应用变换以给定的概率随机应用变换列表
applier = T.RandomApply(transforms=[T.RandomCrop(size=(64, 64))], p=0.5)
transformed_imgs = [applier(orig_img) for _ in range(4)]
plot(transformed_imgs)
更多计算机视觉与图形学相关资料,请关注微信公众号:计算机视觉与图形学实战