详细解读Spatial Transformer Networks(STN)-一篇文章让你完全理解STN了

Spatial Transformer Networks

https://blog.jiangzhenyu.xyz/2018/10/06/Spatial-Transformer-Networks/

2018-10-06

目录

思想介绍

Spatial Transformer Networks 引入了一种新的可学习的模块 Spatial Transformer 。这种模块能够对输入的图像(或者 feature map)进行针对性变换,变换的参数根据输入图像计算出来的,而计算过程的参数是通过学习得到的。这种变换赋予网络空间不变性,也就是对于旋转,移动,拉伸或扭曲的图像,一样具有一定的识别能力,这就是通过 Spatial Transformer 模块将原图变为“正常”的样子来实现的。

原论文中,Spatial Transformer 的构造如下图:

Spatial Transformer

首先通过 Localisation net 计算出变换的参数,然后通过 Grid generator 生成变换的栅格,再根据栅格采样获得变换后的图像。中间的 Localisation net 参数是可学习的,需要让导数传播到栅格数据从而传播到网络中,导数传播的推导见下节。

原理推导

以仿射变换变换为例

 

(xsiysi)=T(Gi)=[θ11θ21θ12θ22θ13θ23]⎛⎝⎜xtiyti1⎞⎠⎟(xisyis)=T(Gi)=[θ11θ12θ13θ21θ22θ23](xityit1)

这里 (xti,yti)(xit,yit) 是输出图上第 i 个像素点的位置,对应的,(xsi,ysi)(xis,yis) 就是原图中对应像素点的位置,根据仿射变换的结果,可以得到输出的每个点对应到原图中的哪个位置,这就是栅格数据。

获得栅格数据之后,利用双线性插值即可获得一个位置处的像素值,公式如下:

 

Vci=∑nH∑mWUcmnmax(0,1−|xsi−m|)max(0,1−|ysi−n|)Vic=∑nH∑mWUmncmax(0,1−|xis−m|)max(0,1−|yis−n|)

VciVic 是输出的第 i 个像素点的数值,H 和 W 分别是原图的高和宽,上式就是遍历原图像素点,找到与目标点相邻的像素再插值。

根据上面的公式,就可以计算导数了:

 

∂Vci∂Ucmn=∑nH∑mWmncmax(0,1−|xsi−m|)max(0,1−|ysi−n|)∂Vic∂Umnc=∑nH∑mWmncmax(0,1−|xis−m|)max(0,1−|yis−n|)

上式计算的是对于原图的导数,如果输入的是一张 feature map 的话,导数就可以继续传播下去了。

 

∂Vci∂xsi=∑nH∑mWUcmnmax(0,1−|ysi−n|)×⎛⎝⎜01−1|m−xsi|≥1m≥xsim<xsi∂Vic∂xis=∑nH∑mWUmncmax(0,1−|yis−n|)×(0|m−xis|≥11m≥xis−1m<xis

这样就可以计算对于 (xsi,ysi)(xis,yis) 的导数,从而计算对于θθ矩阵的导数,进而计算对于 Localisation net 参数的导数了。

简单实现

我搬运 GitHub 里的一个 pytorch 实现 STN 的仓库,简化后实现了一个只做旋转变换的 STN 。

输入数据就是在一定角度范围内旋转过的 MNIST 手写数字图片。

模型代码如下:

下面是一个基本的 CNN 模块,后续的 Localisation net 和 分类网络都是用的这个模块。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class CNN(nn.Module):
    def __init__(self, num_output):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, num_output)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x

下面是分类模块。

1
2
3
4
5
6
7
8
class ClsNet(nn.Module):

    def __init__(self):
        super(ClsNet, self).__init__()
        self.cnn = CNN(10)

    def forward(self, x):
        return F.log_softmax(self.cnn(x), dim=1)

下面是 Localisation net ,可以看出,仅仅是将 CNN 模块的输出改为 1 维的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# get rotate theta
class LocNet(nn.Module):

    def __init__(self):
        super(LocNet, self).__init__()
        self.cnn = CNN(1)

        # zero init
        bias = torch.zeros(1)
        self.cnn.fc2.bias.data.copy_(bias)
        self.cnn.fc2.weight.data.zero_()

    def forward(self, x):
        batch_size = x.size()[0]
        theta = self.cnn(x)
        return theta.view(batch_size)

下面是旋转模块,根据角度构建仿射矩阵,然后用 pytorch 自带的 affine_grid 生成栅格。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# rotate
class RotateGridGen(nn.Module):

    def __init__(self):
        super(RotateGridGen, self).__init__()

    def forward(self, theta, out_size):
        assert len(theta.size()) == 1
        assert type(out_size) == torch.Size
        batch_size = theta.size()[0]
        affine_mat = theta.new(batch_size, 2, 3)
        affine_mat[:, :, 2] = 0
        affine_mat[:, 0, 0] = torch.cos(theta)
        affine_mat[:, 1, 1] = torch.cos(theta)
        affine_mat[:, 0, 1] = -torch.sin(theta)
        affine_mat[:, 1, 0] = torch.sin(theta)
        grid = F.affine_grid(affine_mat, out_size)
        return grid

下面是整体的网络,用 pytorch 自带的 grid_sample 实现双线性插值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Full net
class STNClsNet(nn.Module):

    def __init__(self):
        super(STNClsNet, self).__init__()

        self.loc_net = LocNet()
        self.rotate = RotateGridGen()
        self.cls_net = ClsNet()

    def forward(self, x):
        batch_size = x.size()[0]
        theta = self.loc_net(x)
        grid = self.rotate(theta, x.size())
        transformed_x = F.grid_sample(x, grid)
        logit = self.cls_net(transformed_x)
        return logit

Localisation net 在测试集上的部分输出结果如下图,上面是原图,下面是变换结果:

Result

可以看到,在大多数输入中,Spatial Transformer 成功地将输入图片给“拧正”了,也就是说,在端到端的分类任务中,它学会了如何将数据旋转至便于后续网络分类的形式。

完整代码见我的仓库

总结

Spatial Transformer Networks 提出了一种全新的模块,在端到端的学习中,可以弱监督地学会对输入图片(及特征图)进行空间变换,以方便后续网络完成任务。此外,习得的变换网络的参数也可以用于其他任务

目录

STN的作用 
1.1 灵感来源 
1.2 什么是STN?
STN的基本架构
Localisation net是如何实现参数的选取的? 
3.1 实现平移 
3.2 实现缩放 
3.3 实现旋转 
3.4 实现剪切 
3.5 小结
Grid generator实现像素点坐标的对应关系 
4.1 为什么会有坐标的问题? 
4.2 仿射变换关系
Sampler实现坐标求解的可微性 
5.1 小数坐标问题的提出 
5.2 解决输出坐标为小数的问题 
5.3 Sampler的数学原理
Spatial Transformer Networks(STN)
STN 实现代码
reference
1.STN的作用
1.1 灵感来源

普通的CNN能够显示的学习平移不变性,以及隐式的学习旋转不变性,但attention model 告诉我们,与其让网络隐式的学习到某种能力,不如为网络设计一个显式的处理模块,专门处理以上的各种变换。因此,DeepMind就设计了Spatial Transformer Layer,简称STL来完成这样的功能。

1.2 什么是STN?

关于平移不变性 ,对于CNN来说,如果移动一张图片中的物体,那应该是不太一样的。假设物体在图像的左上角,我们做卷积,采样都不会改变特征的位置,糟糕的事情在我们把特征平滑后后接入了全连接层,而全连接层本身并不具备 平移不变性 的特征。但是 CNN 有一个采样层,假设某个物体移动了很小的范围,经过采样后,它的输出可能和没有移动的时候是一样的,这是 CNN 可以有小范围的平移不变性 的原因。

 
如图所示,如果是手写数字识别,图中只有一小块是数字,其他大部分地区都是黑色的,或者是小噪音。假如要识别,用Transformer Layer层来对图片数据进行旋转缩放,只取其中的一部分,放到之后然后经过CNN就能识别了。

我们发现,它其实也是一个layer,放在了CNN的前面,用来转换输入的图片数据,其实也可以转换feature map,因为feature map说白了就是浓缩的图片数据,所以Transformer layer也可以放到CNN里面。

2. STN的基本架构
 
如图是Spatial Transformer Networks的结构,主要的部分一共有三个,它们的功能和名称如下:
参数预测:Localisation net
参数预测:Localisation net
坐标映射:Grid generator
坐标映射:Grid generator
像素的采集:Sampler
像素的采集:Sampler
为了让大家对这三个部分有一个先验知识,我先简单介绍一下。
如下图是完成的一个平移的功能,这其实就是Spatial Transformer Networks要做一个工作。 
 
假设左边是Layer l−1Layer l−1的输出,也就是当前要做Transform的输入,最右边为Transform后的结果。这个过程是怎么得到的呢?

假设是一个全连接层,n,m代表输出的值在输出矩阵中的下标,输入的值通过权值w,做一个组合,完成这样的变换。

举个例子,假如要生成al11a11l,那就是将左边矩阵的九个输入元素,全部乘以一个权值,加权相加:
al11=wl1111al−111+wl1112al−112+wl1113al−113+⋯+wl1133al−133
a11l=w1111la11l−1+w1112la12l−1+w1113la13l−1+⋯+w1133la33l−1
这仅仅是al11a11l的值,其他的结果也是这样算出来的,用公式表示称如下这样:
 
通过调整这些权值,达到缩放,平移的目的,其实这就是Transformer的思想。

在这个过程中,我们需要面对三个主要的问题:

这些参数应该怎么确定?
图片的像素点可以当成坐标,在平移过程中怎么实现原图片与平移后图片的坐标映射关系?
参数调整过程中,权值一定不可能都是整数,那输出的坐标有可能是小数,但实际坐标都是整数的,如果实现小数与整数之间的连接?
其实定义的三个部分,就是专门为了解决这几个问题的,接下来我们一个一个看一下怎么解决。

3.Localisation net是如何实现参数的选取的?
3.1 实现平移

 
如果是平移变换,比如从al−111平移到al21a11l−1平移到a21l,得到al21a21l的表示为:
al21=wl2111al−111+wl2112al−112+wl2113al−113+⋯+wl2133al−133
a21l=w2111la11l−1+w2112la12l−1+w2113la13l−1+⋯+w2133la33l−1
我们可以令wl2111=1w2111l=1,其余均为0,不就得到了
al21=1∗al−111
a21l=1∗a11l−1
这就完成了平移了吗?其他的平移也可以用类似的方法来做到。
你可能会问了,那我该怎么得到这些权值呢?总不能人工去看吧! 
当然不会,我们可以设置一个叫做NN这类的东西,把Layer l−1Layer l−1的输出放到NN里,然后生成一系列w。这样听起来好玄乎,但确实是可以这么做的。

3.2 实现缩放

其实缩放也不难,如图所示,如果要把图放大来看,在x→(X2)→x′x→(X2)→x′,y→(X2)→y′y→(X2)→y′将其同时乘以2,就达到了放大的效果了,用矩阵表示如下: 
 
缩小也是同样的原理,如果把这张图放到坐标轴来看,就是如图所示,加上偏执值0.5表示向右,向上同时移动0.5的距离,这就完成了缩小。

3.3 实现旋转

既然前面的平移和缩放都是通过权值来改的,那旋转其实也是。但是旋转应该用什么样的权值呢? 
仔细思考,不难发现,旋转是跟角度有关系的,那什么跟角度有关系呢? 
正弦余弦嘛,为什么它们能做旋转呢? 
一个圆圈的角度是360度,可以通过控制水平和竖直两个方向,就能控制了,如图所示。


由点A旋转θθ度角,到达点B.得到
x′=Rcosα
x′=Rcosα
y′=Rsinα
y′=Rsinα
由A点得
x=Rcos(α+θ)
x=Rcos(α+θ)
y=Rsin(α+θ)
y=Rsin(α+θ)
展开,有:
x=Rcosα cosθ−Rsinα sinθ
x=Rcosα cosθ−Rsinα sinθ
y=Rsinα cosθ+Rcosα sinθ
y=Rsinα cosθ+Rcosα sinθ
把未知数αα替换掉
x=x′ cosθ−y′sinθ
x=x′ cosθ−y′sinθ
y=y′cosθ+x′sinθ
y=y′cosθ+x′sinθ

我们可以简单的理解为cosθ,sinθcosθ,sinθ就是控制这样的方向的,把它当成权值参数,写成矩阵形式,就完成了旋转操作。 
 
注:如果想了解正余弦控制方向是怎么导出的,可以参考计算机图形学的相关书籍,一般都有介绍和数学公式的推导。
3.4 实现剪切

剪切变换相当于将图片沿x和y两个方向拉伸,且x方向拉伸长度与y有关,y方向拉伸长度与x有关,用矩阵形式表示前切变换如下: 


3.5 小结

由此,我们发现所有的这些操作,只需要六个参数[2X3]控制就可以了,所以我们可以把feature map U作为输入,过连续若干层计算(如卷积、FC等),回归出参数θ,在我们的例子中就是一个[2,3]大小的6维仿射变换参数,用于下一步计算;

4.Grid generator实现像素点坐标的对应关系
4.1 为什么会有坐标的问题?

由上面的公式,可以发现,无论如何做旋转,缩放,平移,只用到六个参数就可以了,如图所示: 
 
这6个参数,就足以完成我们需要的几个功能了。

而缩放的本质,其实就是在原样本上采样,拿到对应的像素点,通俗点说,就是输出的图片(i,j)的位置上,要对应输入图片的哪个位置? 
 
如图所示旋转缩放操作,我们把像素点看成是坐标中的一个小方格,输入的图片U∈RHxWxCU∈RHxWxC可以是一张图片,或者feature map,其中H表示高,W表示宽,C表示颜色通道。经过变换Tθ(G)Tθ(G),θθ是上一个部分(Localisation net)生成的参数,生成了图片V∈RH′xW′xCV∈RH′xW′xC,它的像素相当于被贴在了图片的固定位置上,用G=GiG=Gi表示,像素点的位置可以表示为Gi={xti,yti}Gi={xit,yit}这就是我们在这一阶段要确定的坐标。

4.2 仿射变换关系

因此定义了如图的一个坐标矩阵变换关系: 
 
(xti,yti)(xit,yit)是输出的目标图片的坐标,(xsi,ysi)(xis,yis)是原图片的坐标,AθAθ表示仿射关系。

但仔细一点,这有一个非常重要的知识点,千万别混淆,我们的坐标映射关系是:
从目标图片→原图片
从目标图片→原图片
也就是说,坐标的映射关系是从目标图片映射到输入图片上的,为什么这样呢?
作者在论文中写的比较模糊,比较满意的解释是坐标映射的作用,其实是让目标图片在原图片上采样,每次从原图片的不同坐标上采集像素到目标图片上,而且要把目标图片贴满,每次目标图片的坐标都要遍历一遍,是固定的,而采集的原图片的坐标是不固定的,因此用这样的映射。

举个自我感觉很贴切的小例子说一下吧。 


如图所示,假设只有平移变换,这个过程就相当于一个拼图的过程,左图是一些像素点,右图是我们的目标,我们的目标是确定的,目标图的方框是确定的,图像也是确定的,这就是我们的目标,我们要从左边的小方块中拿一个小方块放在右边的空白方框上,因为一开始右边的方框是没有图的,只有坐标,为了确定拿过来的这个小方块应该放在哪里,我们需要遍历一遍右边这个方框的坐标,然后再决定应该放在哪个位置。所以每次从左边拿过来的方块是不固定的,而右边待填充的方框却是固定的,所以定义从
目标图片→原图片
目标图片→原图片
的坐标映射关系更加合理,且方便。
5.Sampler实现坐标求解的可微性
5.1 小数坐标问题的提出

我们可以假设一下我们的权值矩阵的参数是如下这几个数,x,y分别是他们的下标,经过变换后,可以得到如下这样的对应。 


前面举的例子中,权值都是整数,那得到的也必定是整数,如果不是整数呢? 
如图所示: 
 
假如权值是小数,拿得到的值也一定是小数,1.6,2.4,但是没有元素的下标索引是小数呀。那不然取最近吧,那就得到2,2了,也就是与al22a22l对应了。 
那这样的方法能用梯度下降来解吗?

5.2 解决输出坐标为小数的问题

用上面的四舍五入显然是不能进行梯度下降来回传梯度的。 
为什么呢? 
梯度下降是一步一步调整的,而且调整的数值都比较小,哪怕权值参数有小范围的变化,虽然最后的输出也会有小范围的变化,比如一步迭代后,结果有:
1.6→1.64,2.4→2.38
1.6→1.64,2.4→2.38
但是即使有这样的改变,结果依然是:
al−122→al22
a22l−1→a22l
的对应关系没有一点变化,所以output依然没有变,我们没有办法微分了,也就是梯度依然为0呀,梯度为0就没有可学习的空间呀。所以我们需要做一个小小的调整。
仔细思考一下这个问题是什么造成的,我们发现其实在推导SVM的时候,我们也遇到过相同的问题,当时我们如果只是记录那些出界的点的个数,好像也是不能求梯度的,当时我们是用了hing loss,来计算一下出界点到边界的距离,来优化那个距离的,我们这里也类似,我们可以计算一下到输出[1.6,2.4]附近的主要元素,如下所示,计算一下输出的结果与他们的下标的距离,可得:


然后做如下更改:


他们对应的权值都是与结果对应的距离相关的,如果目标图片发生了小范围的变化,这个式子也是可以捕捉到这样的变化的,这样就能用梯度下降法来优化了。

5.3 Sampler的数学原理

论文作者对我们前面的过程给出了非常严密的证明过程,以下是我对论文的转述。

每次变换,相当于从原图片(xsi,ysi)(xis,yis)中,经过仿射变换,确定目标图片的像素点坐标(xti,yti)(xit,yit)的过程,这个过程可以用公式表示为: 
 
(注:把一张图片展开,相当于把矩阵变成坐标向量) 
kernel k表示一种线性插值方法,比如双线性插值,更详细的请参考:(线性插值,双线性插值Bilinear Interpolation算法),ϕx,ϕyϕx,ϕy表示插值函数的参数;UcnmUnmc表示位于颜色通道C中坐标为(n,m)的值。

如果使用双线性插值,可以有: 


为了允许反向传播回传损失,我们可以求对该函数求偏导: 
 
对于ysiyis的偏导也类似。

如果就能实现这一步的梯度计算,而对于∂xsi∂θ,∂ysi∂θ∂xis∂θ,∂yis∂θ的求解也很简单,所以整个过程
Localisation net←Grid generator←Sampler
Localisation net←Grid generator←Sampler
的梯度回转就能走通了。
6.Spatial Transformer Networks(STN)
将这三个组块结合起来,就构成了完整STN网络结构了。 
 
这个网络可以加入到CNN的任意位置,而且相应的计算量也很少。

将 spatial transformers 模块集成到 cnn 网络中,允许网络自动地学习如何进行 feature 
map 的转变,从而有助于降低网络训练中整体的代价。定位网络中输出的值,指明了如何对 
每个训练数据进行转化。

7.STN 实现代码
相应的代码已经有人实现了,我就不做重复工作了。 
请参考:Spatial Transformer Networks 
Torch code 
Theano code

8.reference
原论文 
Spatial Transformer 
Spatial Transformer Networks 
卷积神经网络结构变化——Spatial Transformer Networks 
三十分钟理解:线性插值,双线性插值Bilinear Interpolation算法 
Spatial Transformer Networks 笔记 
李宏毅老师的视频讲解
--------------------- 
作者:黄小猿 
来源:CSDN 
原文:https://blog.csdn.net/qq_39422642/article/details/78870629 
版权声明:本文为博主原创文章,转载请附上博文链接!

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值