STN代码理解——NIPS2015
论文链接:https://proceedings.neurips.cc/paper_files/paper/2015/file/33ceb07bf4eeb3da587e268d663aba1a-Paper.pdf
代码链接:https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html#
核心代码:
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
# Spatial transformer localization-network
self.localization = nn.Sequential(
nn.Conv2d(1, 8, kernel_size=7),
nn.MaxPool2d(2, stride=2),
nn.ReLU(True),
nn.Conv2d(8, 10, kernel_size=5),
nn.MaxPool2d(2, stride=2),
nn.ReLU(True)
)
# Regressor for the 3 * 2 affine matrix
self.fc_loc = nn.Sequential(
nn.Linear(10 * 3 * 3, 32),
nn.ReLU(True),
nn.Linear(32, 3 * 2)
)
# Initialize the weights/bias with identity transformation
self.fc_loc[2].weight.data.zero_()
self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
# Spatial transformer network forward function
def stn(self, x):
xs = self.localization(x)
xs = xs.view(-1, 10 * 3 * 3)
theta = self.fc_loc(xs)
theta = theta.view(-1, 2, 3)
grid = F.affine_grid(theta, x.size())
x = F.grid_sample(x, grid)
return x
def forward(self, x):
# transform the input
x = self.stn(x)
# Perform the usual forward pass
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x, dim=1)
1、网络架构
由Spatial Transformer Networks的网络结构图可知,它主要由Localisation net、Grid generator和Sampler三部分组成。
注意:这三部分都是可微的,整个过程可以实现Localisation net<-Grid generator<-Sampler的梯度回传。
1.1 Localisation net
Localisation net用来预测仿射矩阵的参数,也就是图中的
θ
\theta
θ,矩阵大小为
3
×
2
3\times2
3×2。对应代码中的self.localization
和self.fc_loc
。
1.2 Grid generator
Grid generator借助Localisation net学习到的仿射矩阵的参数来实现上一层特征图到下一层特征图上像素点的坐标映射
1.3 Sampler
Sampler 是指可微图像采样(Differentiable Image Sampling)。由Grid generator获得的像素点坐标是小数,但我们需要的坐标是整数,这里的Sampler就是作者提出的将小数转换成整数并且可微的一种方式。
看一个例子,如下图所示,网络学习到的仿射矩阵的参数为 [ 0 0.5 0.6 1 0 0.4 ] \left[ \begin{matrix} 0 & 0.5 & 0.6 \\ 1 & 0 & 0.4 \\ \end{matrix} \right] [010.500.60.4],我们想知道第L层(2, 2)处的像素点值应该对应第L-1层哪一坐标上的像素点值,用放射矩阵对(2, 2)计算后获得坐标(1.6, 2.4),该坐标不是整数。
我们想将其化为整数,最简单的方式是四舍五入,将(1.6, 2.4)四舍五入为(2, 2),直接将第L-1层(2, 2)处的像素点值作为第L层(2, 2)处的像素点值,但四舍五入后的结果不可微分,也就无法反传更新参数,这个方法不可取。
作者没有直接拿L-1层的某一位置的值直接放到第L层。如下图所示,作者没有直接拿L-1层(2, 2)值放到第L层(2, 2)处,而是将L-1层位置离(1.6, 2.4)最近的四个像素值加权求得第L层(2, 2)处的值。这四个像素值加权时的权重由其位置到(1.6, 2.4)的距离求得。
Grid generator和Sampler的实现对应代码中的grid = F.affine_grid(theta, x.size())
和x = F.grid_sample(x, grid)
小节参考:https://www.bilibili.com/video/BV1Ct411T7Ur/?spm_id_from=333.337.search-card.all.click&vd_source=16e43fe0ec489305de9912bf4c941ded
2、补充知识点
看这篇论文时发现自己对仿射变换的认识比较薄弱,查阅了一些相关资料,在这里记录一些我觉得有用的知识点。
2.1 线性变换与仿射变换
仿射变换 = 线性变换 + 平移
ps:平移属于非线性变换
描述\变换 | 线性变换 | 仿射变换 |
---|---|---|
几何描述 | 变换前是直线,变换后依然是直线
直线比例保持不变 变换前后原点不变 | 变换前是直线,变换后依然是直线
直线比例保持不变 |
代数描述 | 通过矩阵乘法实现 | 通过矩阵乘法和矩阵加法实现 |
线性变换的具体实施方法是用一个矩阵左乘待变换的向量。(变换的向量左乘一个矩阵也可以)
通过线性变化完成仿射变换
仿射变换
y
⃗
=
A
x
⃗
+
b
⃗
\vec{y} =A\vec{x}+\vec{b}
y=Ax+b
可以写作
线性变化
[
y
⃗
1
]
=
[
A
b
⃗
0
1
]
[
x
⃗
1
]
\left[ \begin{matrix} \vec{y} \\ 1 \\ \end{matrix} \right] = \left[ \begin{matrix} A & \vec{b} \\ 0 & 1 \\ \end{matrix} \right] \left[ \begin{matrix} \vec{x} \\ 1 \\ \end{matrix} \right]
[y1]=[A0b1][x1]
增加一个维度后,就可以在高维度通过线性变换来完成低维度的仿射变换。
小节参考:https://blog.csdn.net/studyeboy/article/details/113540306
2.2 图形变换的基础概念
图形变换是计算机图形学领域内的重要内容之一。计算机图形学中的图形变换主要有几何变换、坐标变换和观察变换。这些变换有着不同的作用,却又紧密联系在一起。
2.2.1 几何变换
一般来说,图形的几何变换是指对图形的几何信息经过平移、比例、旋转等变换后产生新的图形,即图形在方向、尺寸和形状方面的变换,需要改变图形对象的坐标描述,
2.2.2 齐次坐标
齐次坐标技术是从几何学发展起来的。齐次坐标表示在投影几何中是一种证明理论的工具。有时在n维空间中较难解决的问题,变换到n+1维空间就比较容易得到解决,通过将齐次坐标技术应用到计算机图形学中,是图形变换转化为表示图形的点集矩阵与某一变换矩阵相乘这一单一问题,因而可以借助计算机的高速计算功能,很快得到变换后的图形,从而为高速动态的计算机图形显示提供了可能性。
所谓的齐次坐标表示就是用n+1维向量表示n维向量。例如,二维平面上的点 P ( x , y ) P(x, y) P(x,y)的齐次坐标表示为 ( h x , h y , h ) (hx, hy, h) (hx,hy,h)。这里 h h h是任意不为0的比例系数。规范化齐次坐标表示就是 h = 1 h=1 h=1的齐次坐标表示。类似的,三维空间中的坐标点 P ( x , y , z ) P(x, y, z) P(x,y,z)的齐次坐标表示为 ( h x , h y , h z , h ) (hx, hy, hz, h) (hx,hy,hz,h)。
2.2.3 二维变换矩阵
引入规范化齐次坐标表示后,点P可以用一个矩阵表示,这个矩阵可以是行向量矩阵,也可以是列向量矩阵,即
这里用行向量矩阵形式。这样,二维空间中某点的变换可以表示为点的齐次坐标矩阵与三阶矩阵
T
2
D
T_{2D}
T2D相乘,即
其中,
T
2
D
T_{2D}
T2D被称为二维齐次坐标变换矩阵,简称二维变换矩阵。
从功能上可以将 T 2 D T_{2D} T2D分为4个子矩阵。其中, T 1 = [ a b c d ] T_1= \left[\begin{matrix}a&b \\c&d\\\end{matrix}\right] T1=[acbd]是对图形进行比例、旋转、对称、错切等变换; T 2 = [ l m ] T_2= \left[\begin{matrix}l&m\end{matrix}\right] T2=[lm]是对图形进行平移变换; T 3 = [ p q ] T_3= \left[\begin{matrix}p\\q\end{matrix}\right] T3=[pq]是对图形进行投影变换; T 4 = [ s ] T_4= [s] T4=[s]是对图形进行整体比例变换。
小节参考:《计算机图形学基础》