Abstract
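- CNNs lack the ability to be spatially invariant to their input data.
- The paper introduces a new learnable module, the Spatial Transformer, which allows data to be spatially manipulated inside the network.
- The module actively transforms an image (or feature map) by producing an appropriate transformation for each input sample. The transformation is then applied to the whole feature map (non-locally) and can include scaling, cropping, and so on.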
Spatial Transformers
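- Figure 2: Architecture of the spatial transformer module. The input feature map $U$ is passed to a localisation network, which regresses the transformation parameters $\theta$. The regular spatial grid $G$ over $V$ is transformed into the sampling grid $T_\theta(G)$, which is applied to $U$ to produce the warped output feature map $V$.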
Localisation Network
This component maps the input feature map to the transformation parameters $\theta$, which are used in the next step, the Parameterised Sampling Grid.
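As a rough illustration of this step, here is a minimal sketch of a localisation network in PyTorch. The class name `LocalisationNet`, the layer sizes, and the choice of a small fully connected regressor are assumptions made for this example and are not taken from the paper. The sketch assumes an affine transform, so $\theta$ has 6 values, and it initialises the last layer so that the network starts out predicting the identity transform (bias `[1, 0, 0, 0, 1, 0]`).

```python
import torch
import torch.nn as nn


class LocalisationNet(nn.Module):
    """Hypothetical localisation network: feature map U -> affine parameters theta of shape (N, 2, 3)."""

    def __init__(self, in_channels, height, width, hidden=32):
        super().__init__()
        self.regressor = nn.Sequential(
            nn.Flatten(),                                      # (N, C, H, W) -> (N, C*H*W)
            nn.Linear(in_channels * height * width, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 6),                              # 6 affine parameters
        )
        # Start at the identity transform: zero weights, bias = [1, 0, 0, 0, 1, 0].
        nn.init.zeros_(self.regressor[-1].weight)
        with torch.no_grad():
            self.regressor[-1].bias.copy_(torch.tensor([1., 0., 0., 0., 1., 0.]))

    def forward(self, U):
        return self.regressor(U).view(-1, 2, 3)


loc_net = LocalisationNet(in_channels=1, height=28, width=28)
theta = loc_net(torch.randn(4, 1, 28, 28))  # (4, 2, 3), identity for every sample before training
```

Starting from the identity means the module initially passes its input through unchanged, which tends to make early training more stable.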
Parameterised Sampling Grid
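- Figure 3: (a) The sampling grid is the regular grid $G = T_I(G)$, where $I$ denotes the identity transformation parameters.
(b) The sampling grid is the result of warping the regular grid with an affine transformation $T_\theta(G)$.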
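Implementation details:
In general, the output pixels lie on a regular grid $G = \{G_i\}$ of pixels $G_i = (x_i^t, y_i^t)$, forming an output feature map $V \in \mathbb{R}^{H' \times W' \times C}$. Here $H'$ and $W'$ are the height and width of the grid, and $C$ is the number of channels, which is the same for the input and the output.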
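The pointwise coordinate transformation is:

$$\begin{pmatrix} x_i^s \\ y_i^s \end{pmatrix} = T_\theta(G_i) = A_\theta \begin{pmatrix} x_i^t \\ y_i^t \\ 1 \end{pmatrix} = \begin{bmatrix} \theta_{11} & \theta_{12} & \theta_{13} \\ \theta_{21} & \theta_{22} & \theta_{23} \end{bmatrix} \begin{pmatrix} x_i^t \\ y_i^t \\ 1 \end{pmatrix}$$

where $(x_i^t, y_i^t)$ are the target coordinates of the regular grid in the output feature map, $(x_i^s, y_i^s)$ are the source coordinates in the input feature map that define the sample points, and $A_\theta$ is the affine transformation matrix.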
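- The authors use height- and width-normalised coordinates, so that $-1 \le x_i^t, y_i^t \le 1$ when $(x_i^t, y_i^t)$ lies within the spatial bounds of the output, and $-1 \le x_i^s, y_i^s \le 1$ when $(x_i^s, y_i^s)$ lies within the spatial bounds of the input.
- To apply cropping, translation, rotation, scaling, and skew to the input feature map, the localisation network only needs to produce 6 parameters (the 6 elements of $A_\theta$).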
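The grid construction described above can be sketched in a few lines of PyTorch. The helper name `affine_sampling_grid` is made up for this example. It builds the normalised target grid in $[-1, 1]$, appends the homogeneous 1, and multiplies by $A_\theta$. As a sanity check, the result is compared with PyTorch's built-in `F.affine_grid`, assuming `align_corners=True` so that the corner pixels sit exactly at $\pm 1$.

```python
import torch
import torch.nn.functional as F


def affine_sampling_grid(theta, H, W):
    """Map every target coordinate (x_t, y_t, 1) to a source coordinate (x_s, y_s).

    theta: (N, 2, 3) affine matrices; returns a grid of shape (N, H, W, 2).
    """
    # Normalised target coordinates; indexing='ij' gives tensors of shape (H, W).
    ys, xs = torch.meshgrid(torch.linspace(-1, 1, H), torch.linspace(-1, 1, W), indexing='ij')
    target = torch.stack([xs, ys, torch.ones_like(xs)], dim=-1).reshape(1, H * W, 3)
    source = target @ theta.transpose(1, 2)  # (1, H*W, 3) @ (N, 3, 2) -> (N, H*W, 2)
    return source.reshape(-1, H, W, 2)


# Scaling by 0.5 samples only the central half of the input, i.e. a zoom-in (crop).
theta = torch.tensor([[[0.5, 0.0, 0.0],
                       [0.0, 0.5, 0.0]]])
grid = affine_sampling_grid(theta, H=3, W=4)
reference = F.affine_grid(theta, [1, 1, 3, 4], align_corners=True)
```

With this $\theta$, the corners of the sampling grid land at $(\pm 0.5, \pm 0.5)$ in the input, which is how a pure scaling acts as a crop.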
Differentiable Image Sampling
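After the matrix operation, most of the resulting coordinates are floating-point values, which do not correspond to the integer pixel coordinates of the feature map, so the output pixel values have to be computed by interpolation:

$$V_i^c = \sum_n^{H} \sum_m^{W} U_{nm}^c \, k(x_i^s - m; \Phi_x) \, k(y_i^s - n; \Phi_y) \quad \forall i \in [1 \ldots H'W'], \ \forall c \in [1 \ldots C]$$

where $\Phi_x$ and $\Phi_y$ are the parameters of a generic sampling kernel $k()$ that defines the image interpolation (for example, bilinear), $U_{nm}^c$ is the value at location $(n, m)$ in channel $c$ of the input, and $V_i^c$ is the output value for pixel $i$ in channel $c$.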
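When $k()$ is the bilinear sampling kernel, this becomes:

$$V_i^c = \sum_n^{H} \sum_m^{W} U_{nm}^c \, \max(0, 1 - |x_i^s - m|) \, \max(0, 1 - |y_i^s - n|)$$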
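For bilinear sampling (Eq. 5 in the paper), the paper gives the partial derivatives as:

$$\frac{\partial V_i^c}{\partial U_{nm}^c} = \sum_n^{H} \sum_m^{W} \max(0, 1 - |x_i^s - m|) \, \max(0, 1 - |y_i^s - n|)$$

$$\frac{\partial V_i^c}{\partial x_i^s} = \sum_n^{H} \sum_m^{W} U_{nm}^c \, \max(0, 1 - |y_i^s - n|) \begin{cases} 0 & \text{if } |m - x_i^s| \ge 1 \\ 1 & \text{if } m \ge x_i^s \\ -1 & \text{if } m < x_i^s \end{cases}$$

and similarly for $\partial V_i^c / \partial y_i^s$.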
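To make the bilinear kernel and its gradients concrete, the sketch below transcribes the sum directly into PyTorch and lets autograd compute the derivatives with respect to the sampling coordinates. The function name `bilinear_sample` is an assumption for this example, and for readability it works on a single channel with unnormalised pixel coordinates (column index `m`, row index `n`) rather than the $[-1, 1]$ range used elsewhere.

```python
import torch


def bilinear_sample(U, xs, ys):
    """Sample a single-channel map U of shape (H, W) at float pixel coordinates.

    V_i = sum_n sum_m U[n, m] * max(0, 1 - |x_i - m|) * max(0, 1 - |y_i - n|)
    """
    H, W = U.shape
    m = torch.arange(W, dtype=U.dtype)  # column indices
    n = torch.arange(H, dtype=U.dtype)  # row indices
    wx = torch.clamp(1 - (xs[:, None] - m[None, :]).abs(), min=0)  # (P, W) kernel weights in x
    wy = torch.clamp(1 - (ys[:, None] - n[None, :]).abs(), min=0)  # (P, H) kernel weights in y
    return torch.einsum('pn,nm,pm->p', wy, U, wx)


U = torch.tensor([[0., 10.],
                  [20., 30.]])
xs = torch.tensor([0.25], requires_grad=True)
ys = torch.tensor([0.5], requires_grad=True)

V = bilinear_sample(U, xs, ys)  # V = 12.5
V.sum().backward()
# xs.grad = 10.0 and ys.grad = 20.0: the value differences between neighbouring columns / rows
```

The gradients are piecewise constant in the sampling coordinates, which matches the case analysis in the derivative above and is what allows the loss to flow back into $\theta$.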
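Experiments
The authors first try affine transformations on MNIST, then run experiments on SVHN multi-digit recognition and on the CUB-200-2011 bird classification dataset.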
Try
Here we try out the MNIST experiment from the paper.
import math
import copy
import torch
import torch.nn.functional as F

def vec_to_perpective_matrix(vec):
    # vec rep of the perspective transform has 8 dof; so add 1 for the bottom right of the perspective matrix;
    # note network is initialized to transformer layer bias = [1, 0, 0, 0, 1, 0] so no need to add an identity matrix here
    ones = torch.ones((vec.shape[0], 1), dtype=vec.dtype, device=vec.device)
    return torch.cat((vec, ones), dim=1).view(-1, 3, 3)  # (N, 8) -> (N, 3, 3)

def gen_random_perspective_transform(params):
    """ generate a batch of 3x3 homography matrices by composing rotation, translation, shear, and projection matrices,
    where each samples components from a uniform(-1,1) * multiplicative_factor
    """
    batch_size = params.batch_size

    # debugging
    if params.dict.get('identity_transform_only'):
        return torch.eye(3).repeat(batch_size, 1, 1).to(params.device)

    I = torch.eye(3).repeat(batch_size, 1, 1)
    uniform = torch.distributions.Uniform(-1, 1)
    factor = 0.25

    # rotation component: angle in [-pi/6, pi/6]
    a = math.pi / 6 * uniform.sample((batch_size,))
    R = I.clone()
    R[:, 0, 0] = torch.cos(a)
    R[:, 0, 1] = - torch.sin(a)
    R[:, 1, 0] = torch.sin(a)
    R[:, 1, 1] = torch.cos(a)

    # translation component
    tx = factor * uniform.sample((batch_size,))
    ty = factor * uniform.sample((batch_size,))
    T = I.clone()
    T[:, 0, 2] = tx
    T[:, 1, 2] = ty

    # shear component
    sx = factor * uniform.sample((batch_size,))
    sy = factor * uniform.sample((batch_size,))
    A = I.clone()
    A[:, 0, 1] = sx
    A[:, 1, 0] = sy

    # projective component
    px = uniform.sample((batch_size,))
    py = uniform.sample((batch_size,))
    P = I.clone()
    P[:, 2, 0] = px
    P[:, 2, 1] = py

    # compose the homography; Tensor.to() is not in-place, so move the result once here
    H = R @ T @ P @ A
    return H.to(params.device)

def apply_transform_to_batch(im_batch_tensor, transform_tensor):
    """ apply a geometric transform to a batch of image tensors
    args
        im_batch_tensor -- torch float tensor of shape (N, C, H, W)
        transform_tensor -- torch float tensor of shape (1, 3, 3)

    returns
        transformed_batch_tensor -- torch float tensor of shape (N, C, H, W)
    """
    N, C, H, W = im_batch_tensor.shape
    device = im_batch_tensor.device

    # torch.nn.functional.grid_sample takes a grid in [-1,1] and interpolates;
    # construct grid in homogeneous coordinates: x runs along the width, y along the height
    # (the indexing argument requires a reasonably recent PyTorch)
    y, x = torch.meshgrid(torch.linspace(-1, 1, H), torch.linspace(-1, 1, W), indexing='ij')
    x, y = x.flatten(), y.flatten()
    xy_hom = torch.stack([x, y, torch.ones_like(x)], dim=0).unsqueeze(0).to(device)

    # transform the [-1,1] homogeneous coords
    xy_transformed = transform_tensor.matmul(xy_hom)  # (1, 3, 3) matmul (1, 3, H*W) > (1, 3, H*W)

    # convert to inhomogeneous coords -- cf Szeliski eq. 2.21
    grid = xy_transformed[:, :2, :] / (xy_transformed[:, 2, :].unsqueeze(1) + 1e-9)
    grid = grid.permute(0, 2, 1).reshape(-1, H, W, 2)  # (1, H, W, 2); cf torch.nn.functional.grid_sample
    grid = grid.expand(N, *grid.shape[1:])  # expand to minibatch

    # grid[..., 0] is already the x (width) coordinate, so no transpose of the output is needed
    return F.grid_sample(im_batch_tensor, grid, mode='bilinear', align_corners=True)

# --------------------
# Test
# --------------------
def test_get_random_perspective_transform():
    import matplotlib
    matplotlib.use('TkAgg')
    import numpy as np
    import matplotlib.pyplot as plt
    from unittest.mock import Mock

    # the random transform is sampled with torch, so seed torch rather than numpy
    torch.manual_seed(6)

    # 15x15 grayscale test image
    imt = np.array([
        [ 1 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ],
        [ 1 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ],
        [ 1 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ],
        [ 0 , 0 , 0 , 4 , 4 , 6 , 2 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ],
        [ 0 , 0 , 0 ,219,250,253,196,203,202,199,198, 53, 0 , 0 , 0 ],
        [ 0 , 0 , 0 , 1 , 0 , 1 ,62 ,73 ,68 ,62 ,236,104, 0 , 0 , 0 ],
        [ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,24 ,253, 3 , 0 , 0 , 0 ],
        [ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,10 ,247,61 , 0 , 0 , 0 , 0 ],
        [ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 8 ,250, 4 , 1 , 0 , 0 , 0 ],
        [ 0 , 0 , 0 , 0 , 0 , 0 , 0 ,21 ,106,185, 0 , 0 , 0 , 0 , 0 ],
        [ 0 , 0 , 0 , 0 , 0 , 0 , 0 ,219,248, 1 , 0 , 0 , 0 , 0 , 0 ],
        [ 0 , 0 , 0 , 0 , 0 , 0 , 38,254, 75, 1 , 0 , 0 , 0 , 0 , 0 ],
        [ 0 , 0 , 0 , 0 , 0 , 1 ,252,62 , 1 , 1 , 0 , 0 , 0 , 0 , 0 ],
        [ 0 , 0 , 0 , 0 , 0 ,121,252,30 , 1 , 1 , 0 , 0 , 0 , 0 , 0 ],
        [ 0 , 0 , 0 , 0 , 0 , 8 , 3 , 0 , 0 , 0 , 0 , 0 , 1 , 1 , 1 ]])

    # get transform
    params = Mock()
    params.batch_size = 1
    params.dict = {'identity_transform_only': False}
    params.device = torch.device('cpu')
    H = gen_random_perspective_transform(params)

    # add batch and channel dimensions: (15, 15) -> (1, 1, 15, 15)
    imt = torch.FloatTensor(imt[np.newaxis, np.newaxis, ...])
    imt_transformed = apply_transform_to_batch(imt, H)

    # left: original image; right: transformed and interpolated image
    fig, axs = plt.subplots(1, 2)
    axs[0].imshow(imt.squeeze().numpy(), cmap='gray')
    axs[1].imshow(imt_transformed.squeeze().numpy(), cmap='gray')
    for ax in axs:
        ax.axis('off')
    plt.tight_layout()
    plt.show()


if __name__ == '__main__':
    test_get_random_perspective_transform()

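- Left: the original image. Right: the image after transformation and interpolation.
Summary
- STN is similar to the attention mechanisms studied earlier: both learn an adaptive matrix from the features and apply it back to the original features. The difference is that STN adjusts the coordinates (positions) of the pixels rather than their values.
- Cropping, translating, rotating, and similar operations on an image (or feature map) can all be written as matrix operations on its coordinates, so the paper uses a 6-parameter ($2 \times 3$) matrix to apply an affine transformation to the feature map.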
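The point that the transform acts on coordinates rather than on pixel values can be checked with a small sketch built on PyTorch's `F.affine_grid` and `F.grid_sample`. The helper name `spatial_transform` and the hand-written $\theta$ values are assumptions for this example; in a real STN, $\theta$ would come from the localisation network.

```python
import torch
import torch.nn.functional as F


def spatial_transform(U, theta):
    """Warp a batch U of shape (N, C, H, W) with affine matrices theta of shape (N, 2, 3)."""
    grid = F.affine_grid(theta, list(U.shape), align_corners=True)      # sampling grid T_theta(G)
    return F.grid_sample(U, grid, mode='bilinear', align_corners=True)  # bilinear sampling of U


U = torch.arange(12.).reshape(1, 1, 3, 4)

identity = torch.tensor([[[1., 0., 0.],
                          [0., 1., 0.]]])
flip_x = torch.tensor([[[-1., 0., 0.],
                        [0., 1., 0.]]])  # negate the x coordinate only

same = spatial_transform(U, identity)    # reproduces U
mirrored = spatial_transform(U, flip_x)  # U mirrored left to right
```

Only the sampling coordinates change between the two calls. The pixel values themselves are never modified, just read from different positions.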