目录
2.3 Grid generator实现像素点坐标的对应关系
.1简介
STN是一个可以加在网络中间的模块,使得网络能够对图像变形有适用性
比如加入了这个模块训练出来的模型,就会对变形的物体有一定的识别能力
因为模型里包含的参数是对数据进行仿射变换
本文提出了一种叫做空间变换网络(Spatial Transform Networks, STN)的网络模型,该网络不需要关键点的标定,能够根据分类或者其它任务自适应地将数据进行空间变换和对齐(包括平移、缩放、旋转以及其它几何变换等)。在输入数据空间差异较大的情况下,这个网络可以加在现有的卷积网络中,提高分类的准确性。
比如:
例如对于上图中输入手写字体,我们感兴趣的是黄色框中的包含数字的区域,那么在训练的过程中,学习到的空间变换网络会自动提取黄色框中的局部数据特征,并对框内的数据进行空间变换,得到输出output。
.2 空间变换网络原理详解
2.1 概述
第一部分为为”localization net””localization net”网络中的参数则为空间变换网络需要训练的参数;
第二部分就是空间变换即仿射变换。通过该局部网络产生仿射变换系数θ
2.2 Localisation net
如下图是完成的一个平移的功能,这其实就是Spatial Transformer Networks要做一个工作。
2.3 Grid generator实现像素点坐标的对应关系
得到变换前后的坐标映射关系
2.4 Sampler实现坐标求解的可微性
如下所示,计算一下输出的结果与他们的下标的距离,可得:
然后做如下更改:
数学公式论证
.3 空间变换网络的实际应用
以上讲解的是空间变换网络的理解,那么在实际应用中,我们该如何添加空间变换网络到我们自己的网络中呢?接下来重点讲解空间变换网络的应用。
3.1.空间变换网络作为网络的第一层
空间变换网络可以直接作为网络的第一层,即Localisation Net的输入为input,从而直接对输入进行仿射变换,对于Localisation Net的设计,可以根据输入input的大小设计Localisation Net为全连接层或卷积层.
例如对于手写字体,输入图片大小为40x40,即input=[batch_size,1600],那么我们可以设计Localisation Net包含两个全连接层,第一个全连接层w1=[1600,20],b1=[20],第二个全连接层w2=[20,6],b2=[6],则第二个全连接层的输出为[batch_size,6],即为仿射变换系数。
3.2.空间变换网络插入CNN的中间层
空间变换网络还可以添加在CNN的中间层,可以直接将空间变换网络插入conv或者max-pooling层的前面或者后面。此外,还可以在CNN的同一层插入多个空间变换网络,下面给出空间变换网络插入CNN的手写字体网络结构图:
上图中第一个空间变换网络ST1作用于输入图像,直接对输入图像进行空间变换,第二、三个空间变换网络ST2a,ST2b作用于conv1,用于对第一层的卷积特征进行空间变换,而ST3用于对更深层的卷积特征进行空间变换。
由于空间变换网络能够自动提取局部区域特征,因此在网络的同一层插入父哦个空间变换网络可以提取多个局部区域特征,从而可以结合多个局部区域特征进行分类:
如下如的网络是实现两张输入的图片中的手写字体相加,在网络的第一层插入两层空间变换网络ST1,ST2,并将其直接作用语输入图像。图中第三列为空间变换结果,有图可知,网络ST1,ST2分别提取了输入手写字体的不同区域的特征
4. 代码分析
首先看一仿射变换的代码实现,代码的实现如上所述,首先由函数_meshgrid生成输出V的坐标位置点grid,在通过仿射变换系数theta对grid进行仿射变换得到U中对于位置坐标点T_g,之后对T_g进行双线性插值,并复制插值后的U中的坐标点的像素值到V中,得到输出V。
def transform(theta, input_dim, out_size):
with tf.variable_scope('_transform'):
num_batch = tf.shape(input_dim)[0]
height = tf.shape(input_dim)[1]
width = tf.shape(input_dim)[2]
num_channels = tf.shape(input_dim)[3]
theta = tf.reshape(theta, (-1, 2, 3))
theta = tf.cast(theta, 'float32')
# grid of (x_t, y_t, 1), eq (1) in ref [1]
height_f = tf.cast(height, 'float32')
width_f = tf.cast(width, 'float32')
out_height = out_size[0]
out_width = out_size[1]
grid = _meshgrid(out_height, out_width)
grid = tf.expand_dims(grid, 0)
grid = tf.reshape(grid, [-1])
grid = tf.tile(grid, tf.pack([num_batch]))
grid = tf.reshape(grid, tf.pack([num_batch, 3, -1]))#得到输出坐标位置点
# Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
T_g = tf.batch_matmul(theta, grid)#仿射变换
x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])#
y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
x_s_flat = tf.reshape(x_s, [-1])
y_s_flat = tf.reshape(y_s, [-1])
input_transformed = _interpolate(
input_dim, x_s_flat, y_s_flat,
out_size)#插值,并得到输出
output = tf.reshape(
input_transformed, tf.pack([num_batch, out_height, out_width, num_channels]))
return output
完整代码如下:
https://github.com/tensorflow/models/blob/master/transformer/cluttered_mnist.py
references:
详细解读Spatial Transformer Networks(STN)-一篇文章让你完全理解STN了_黄小猿的博客-CSDN博客_stn