转载声明,借鉴于这里
pytorch里面的仿射变换,写出一些关于自己的理解。
1:基本原理
首先我们先搞明白旋转,其实很简单,如下图。就是基本的三角函数。其中其中
ρ
\rho
ρ表示
(
x
,
y
)
(x, y)
(x,y)点距离原点的距离。
起始点
(
x
,
y
)
(x, y)
(x,y)逆时针旋转
θ
\theta
θ度以后,到
(
x
′
,
y
′
)
(x',y')
(x′,y′)处,所以我们只要给出
θ
\theta
θ,就能通过上述公式,实现旋转。同时我们注意到在上述公式中,尾巴后面有个加[0, 0]
的操作,其实这个操作加[0, 0]
的操作就是实现该点的平移。如果我们把[0, 0]
换成[500, 100]
,并且不进行旋转,这就等于让该点在水平方向上右移500
,垂直方向上往下移动100
。如下图所示。
所以对图像的仿射变换,一共包含六个参数,如下图所示,顾名思义,前两列的参数
[
A
,
B
,
C
,
D
]
[A,B,C,D]
[A,B,C,D]用于实现旋转,最后一列的参数
[
E
,
F
]
[E,F]
[E,F]用于平移。
2:实际操作
好了,原理就那么多,开始上代码。先可视化图片。
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.nn import functional as F
import math
img_path = "img/img.png"
img_torch = transforms.ToTensor()(Image.open(img_path))
plt.imshow(img_torch.numpy().transpose(1, 2, 0))
plt.show()
下面开始仿射变换,进行仿射变换会用到affine_grid和grid_sample这两个函数。这两个函数下面解释,我觉得主要的还是要搞明白这个仿射变换矩阵的 作用是什么,为什么它是一个两行三列的矩阵。旋转结果如下图所示。当然
我这里没有进行平移操作。
# -----------------------
# 角度为-30度,就是顺时针旋转
# -----------------------
angle = -30 * math.pi / 180
# -----------------------
# 构造仿射变换矩阵,前两列用于旋转,
# 最后一列用于平移
# -----------------------
theta = torch.tensor([
[math.cos(angle), math.sin(-angle), 0],
[math.sin(angle), math.cos(angle), 0]
], dtype=torch.float)
# -----------------------
# pytorch中仿射变换需要用到
# affine_grid和grid_sample两个函数
# -----------------------
grid = F.affine_grid(theta.unsqueeze(0), img_torch.unsqueeze(0).size(), align_corners=True)
output = F.grid_sample(img_torch.unsqueeze(0), grid, align_corners=True)
new_img_torch = output[0]
plt.imshow(new_img_torch.numpy().transpose(1, 2, 0))
plt.show()
3:函数的说明
要使用 pytorch 的进行仿射操作,只需要两步:
第一步:创建网格,代码为torch.nn.functional.affine_grid(theta, size)
其中theta
就是我们之前说的仿射变换矩阵,size
为希望仿射变换后图像的大小,其实我们可以通过调节size
设置所得到的图像的大小(相当于resize);
第二步:进行重采样:代码为torch.nn.functional.grid_sample(inputs, grid, mode='bilinear')
,其中inputs
为输入的图像转换后的tensor
,grid
就是第一步的结果,mode
可以设定为双线性插值等。