内涵:STN(spatial transformer network)论文与源码理解

1.引言

  近期阅读了2015年的一篇较为经典的论文"spatial transformer networks(stn)"。本博文是stn阅读心得的记录。在第二小节中,会描述stn的实现细节,包括三大组成构件:localisation network、Grid generator、Sampler。在第三小节中会通过跟踪stn源码(pytorch官方版本)来验证自己的理解正确性。在第四部分作为扩展部分,会尝试从数学角度阐述STN的数学形式并作可导性分析。

2.STN是如何进行的

在这里插入图片描述
                    图1
  spatial transformer networks的提出背景:通常为了使模型在测试阶段spatial invariance, 一种常规的做法是在训练阶段做尽可能丰富的数据扩增操作(eg.shift, crop等)。而stn则是将数据扩增有机的和网络融为一体,达到learnable的效果。从实验结果来看,可较显著的提升(分类)模型的性能。
  stn的核心是如图1所示的spatial transformer模块。

名称说明
U输入特征,为spatial transformer的输入
V输出特征,为spatial transformer的输出
localisation netst模块的三大构件之一,后文会详述
Grid generatorst模块的三大构件之一,后文会详述
Samplerst模块的三大构件之一,后文会详述

              表1

2.1 localisation net

在这里插入图片描述
                    图2

   Localisation net的作用是回归仿射变换的参数 θ \theta θ。图3中的公式是仿射变换操作的通式,二维空间上仿射变换的参数为6个,也即localisation net的输入为 N ∗ C ∗ H ∗ W N*C*H*W NCHW的特征图,输出为 N ∗ 6 N*6 N6
在这里插入图片描述
                    图3
  Localisation net部分的实现就是以Conv层和Linear层构成 说明 1 ^{说明1} 说明1,具体如图4所示。这部分比较直观,就不做赘述。

在这里插入图片描述
                    图4

2.2 Grid generator

在这里插入图片描述
              图5
  这一部分的作用是建立输出特征图中的坐标与输入特征图中的坐标关系。过程像素级别的操作可以用图6来表示
在这里插入图片描述              图6
  关于图6中的公式需要注意两点:

  • 该公式是对 x , y x,y x,y坐标进行操作,而不是feature map的值
  • 人直观的感受可能会写作 s o u r c e ∗ θ − > t a r g e t source*\theta->target sourceθ>target,但如果从实际代码撰写的角度来出发,会更好的理解图6中写法的原因。
      以实际的例子,来描述这一过程:
    在这里插入图片描述
                  图7

  以仿射变换的一种特例,顺时针旋转90度为例。
  对于输出特征图上位置 ( 0 , 0 ) (0,0) (0,0)处的值’2’来自于输入特征图上的 ( 0 , 2 ) (0,2) (0,2)处。
在这里插入图片描述
  对于输出特征图上位置 ( 0 , 1 ) (0,1) (0,1)处的值‘3’来自于输入特征图上的 ( 0 , 1 ) (0,1) (0,1)处。
在这里插入图片描述
  对于输出特征图上位置 ( 0 , 2 ) (0,2) (0,2)处的值‘1’来自于输入特征图上的 ( 0 , 0 ) (0,0) (0,0)处。
在这里插入图片描述
  按照此规律,可以得到输出特征图上点的所有“来源”。

2.3 Sampler

  通过2.2节中描述的Grid generator。可以得到输出特征图上各个value的"来源"矩阵:
在这里插入图片描述
  而Sampler的过程就是基于该“来源”矩阵取索引处值的过程
在这里插入图片描述
  STN实际上的Sampler要比这里描述的复杂一些,因为它还会涉及到一个插值操作。回到STN, 在2.1节中,已经讲明, θ \theta θ是网络学习出来的,旋转只是仿射变换的一种特例。因此大概率计算得到的“来源”并不是一个整数。
  仍旧以一个实际例子来说明,假如在当前iteration,学习得到的 θ \theta θ为:
在这里插入图片描述
  那么对于target 特征图中(0,0)处值的来源为source特征图中的(0.3,0.7)。为了处理这种坐标非整数的情形,就需要利用插值:用其附近的四个整数坐标的value来生成。图8展示了二维插值的计算方式示意图。
在这里插入图片描述
              图8

3.以源码的方式验证自己理解的正确性

  pytorch已经将stn集成,并提供了stn pytorch tutorials。本部分主要是跟踪其中的代码,来完善并验证上述的理解。

3.1 localisalization net相关代码

  这部分直接贴相关核心代码,细节不再赘述。可以较容易的与图4中的内容对应起来。

  • 核心代码段1
xs = self.localization(x)
xs = xs.view(-1, 10 * 3 * 3)
theta = self.fc_loc(xs)
theta = theta.view(-1, 2, 3)

*核心代码段2

# Spatial transformer localization-network
self.localization = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=7),
    nn.MaxPool2d(2, stride=2),
    nn.ReLU(True),
    nn.Conv2d(8, 10, kernel_size=5),
    nn.MaxPool2d(2, stride=2),
    nn.ReLU(True)
)

# Regressor for the 3 * 2 affine matrix
self.fc_loc = nn.Sequential(
    nn.Linear(10 * 3 * 3, 32),
    nn.ReLU(True),
    nn.Linear(32, 3 * 2)
)

3.2 Grid generator和Sampler相关代码

grid = F.affine_grid(theta, x.size())
x = F.grid_sample(x, grid)

  这部分以实际模型训练中某一iteration的实际例子来进行说明,此时 θ \theta θ
在这里插入图片描述

x . s i z e ( ) x.size() x.size() 64 ∗ 1 ∗ 28 ∗ 28 64*1*28*28 6412828。2.2节以及2.3节中的理解基本是正确的,但pytorch在具体实施的过程中,有两点需要注意:

  1. target中的坐标是被归一化到[-1, 1],然后才利用图6中的公式进行计算(也即计算得到的“根源”坐标也为归一化坐标);
  2. 由于在当前语境下,会用到插值,因此每一个特征被认为是一个1*1的area, 只有area的中心点为特征值(这点看似废话,实际很重要具体可以看网友的讨论);

  因此这里的归一化公式为:
x n o r m = ( 2 ∗ x + 1 ) / s − 1 x_{norm} =(2*x+1)/s-1 xnorm=(2x+1)/s1
y n o r m = ( 2 ∗ y + 1 ) / s − 1 y_{norm} =(2*y+1)/s-1 ynorm=(2y+1)/s1
  反归一化公式为
x = ( ( x n o r m + 1 ) ∗ s − 1 ) / 2 x=((x_{norm}+1)*s-1)/2 x=((xnorm+1)s1)/2
y = ( ( y n o r m + 1 ) ∗ s − 1 ) / 2 y=((y_{norm}+1)*s-1)/2 y=((ynorm+1)s1)/2

  按照2.2中的理解,计算target特征图中(13,5)在source特征图中的来源。
  step1:先利用归一化公式操作得到 ( x n o r m , y n o r m ) = ( − 0.03571 , − 0.6071 ) (x_{norm},y_{norm})=(-0.03571,-0.6071) (xnorm,ynorm)=(0.03571,0.6071);
  step2:与 θ \theta θ相乘,得到输入特征图上的归一化坐标 ( x n , y n ) = ( − 0.1428 , − 0.6859 ) (x_n, y_n)=(-0.1428,-0.6859) (xn,yn)=(0.1428,0.6859),与调试的代码结果一致。
在这里插入图片描述
  step3:对 ( x n , y n ) = ( − 0.1428 , − 0.6859 ) (x_n, y_n)=(-0.1428,-0.6859) (xn,yn)=(0.1428,0.6859)反归一化,得到输入特征图上的非归一化坐标(15.4994, 23.10285)。
  step4:插值的四个坐标对应的特征值为:
在这里插入图片描述
根据图8中公式,可以算得stn输出特征图中x(13,5)处的值为2.4648。而代码打印的结果为2.4819,有一定的误差,但基本与预期相符。
在这里插入图片描述
  以上基本证明了自己对于stn的如何实施的理解正确性。

4.扩展:STN的可导性分析

  第二节,第三节描述了stn的实施细节。但仅仅有这些还不够,我们在设计一个“创新性的”网络结构时,起效的前提或者说理论基础是该模块是differentiable。

4.1 STN的前向公式分析

  论文中给出的前向公式是:
在这里插入图片描述

在阐述该公式时,先暂时忘却这一公式,看一看按照之前的理解,会如何写这一过程:
在这里插入图片描述

在这里插入图片描述
上述公式等价于
在这里插入图片描述
进一步等价于
在这里插入图片描述
再做一点就可以将上述公式中四个sum因子,写成一个通式:
在这里插入图片描述
  再继续想,我们认为和输出特征图上 i t h ith ith点有关的是4个点是一种很自然的想法。但对于pytorch来讲,需要矩阵的操作,不可能仅仅是4个点。因此上述公式,又要进一步进行转换:
在这里插入图片描述
上述公式可以巧妙的将非附近4个点的其他系数计算为0,从而即完成了整个输入特征的计算形式,又达到了实际仅附近4个点参与的效果。

4.2 STN的导数公式分析

  论文中给出的导数公式为:
在这里插入图片描述
  在对前向公式(5)的已经存在的情况下,得到上述两个偏微分公式并不复杂,因此本小节想讲一讲其他的地方。

  1. 公式(7)具有重要的意义:它在对坐标求导数。这是一个值得注意的地方。因为我们之前遇到的一些常规的CNN模块,可能要么很少这样做。
  2. Spatial transofrmer的backward过程再进一步说明一下。

在这里插入图片描述
公式(6)和公式(7)分别对应图中的圈1和圈2。进一步的可以写出圈3处的求导公式(大概形式):
在这里插入图片描述
在这里插入图片描述
可以看到 θ \theta θ是可以学习的。且在圈4时,反向传播已经转换为常规的CNN Bp操作了。

5.反思

  本篇论文给我的启发有4点:

  1. 提供了一种很好的范例,如何将传统的图像处理操作,融为深度学习可学习版本。
  2. 对非feature map的求导学习操作比较少见,本文的该思想做法同样有比较大的启发。
  3. 本文可导性的分析,值得借鉴。深度学习绝不是简简单单的炼丹,其实一旦有了诸如此的数学基础。这样写出来的代码大概率是work的。
  4. 目前来看stn只能适用于分类网络,可以尝试对其进行怎样的修改,推广到目标检测。

6.实践后续

  在近期实际实践的过程中,有以下注意事项:

  1. 新的发现:stn的示例代码中,有一个需要注意的点:localisation net的最后一层是一个linear层,而对linear做特殊的初始化,可以保证stn的转换从等价转换开始,保证了训练的稳定性,这一点很重要。
# Initialize the weights/bias with identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
  • 6
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
### 回答1: 空间变换网络(Spatial Transformer Network)是一种神经网络模型,它可以对输入图像进行空间变换,从而提高模型的鲁棒性和准确性。该模型可以自动学习如何对输入图像进行旋转、缩放、平移等变换,从而使得模型可以更好地适应不同的输入数据。空间变换网络在计算机视觉领域中得到了广泛的应用,例如图像分类、目标检测、人脸识别等任务。 ### 回答2: 空间变换网络(spatial transformer network)是一种能够自适应地对输入图像进行几何变换的神经网络结构。它最早由Jaderberg等人在2015年提出,是深度学习与计算机视觉中一个重要的技术,目前被广泛应用于机器人视觉、自动驾驶、图像识别、跟踪以及目标定位等领域。 空间变换网络通过学习一个仿射变换矩阵来对输入的图像进行变换,其核心思想是在网络中引入一个可微的位置网格,通过对位置网格上的点进行仿射变换,实现图像的空间变换。 在传统的CNN结构里,其特征提取部分是不变形的,即无论输入图像发生多少位移缩放等操作,神经网络都不能自适应地对这些变化进行相应的调整,因此就不能很好地完成图像识别等任务。而引入空间变换网络后,可以使神经网络能够在学习中自适应地进行对图像缩放、旋转、平移、倾斜等变换,从而提高模型的鲁棒性和识别效果。 空间变换网络的结构一般由三部分组成:特征提取层、坐标生成层和采样网络层。其中,特征提取层可以是任何现有的CNN层,坐标生成层则用来生成仿射变换矩阵(包括平移、旋转、缩放、扭曲等形式),采样网络层则通过仿射变换将输入图像的像素按照特定的网格结构进行采样和变换。 空间变换网络具有以下优点:一、能够适应不同角度、缩放和扭曲程度的图像变换;二、减少了过拟合的风险,因为其能够从小规模的训练数据中学习到更广泛的图像变换范式;三、能够提高卷积神经网络的准确性和鲁棒性,使其具有更好的视觉推理能力;四、具有广泛的应用前景,除了在图像分类、物体识别等领域,还可以应用于姿态识别、图像检索、视觉跟踪等任务。 总之,在深度学习与计算机视觉领域,空间变换网络是非常重要的一个属性,其有效地解决了图像的仿射变换问题,为更广泛的应用提供了重要的方法和技术支持。 ### 回答3: Spatial Transformer NetworkSTN)是一种深度学习中的可学习的空间变换网络,可以自动化地学习如何将输入图像转换或标准化成一个特定形式的输出。STN主要由3个部分构成,分别为定位网络、网格生成器和采样器。 定位网络用于学习如何从输入图像中自动检测出需要进行变换的区域,并进一步学习该区域需要发生的变换类型和程度。然后网格生成器利用学习到的变换参数生成一个新的位置网格,将变换后的特征图从原始输入中分离出来。最后采样器将变换后的特征网格映射回原输入图像,并将其传递给下一层网络进行后续的处理。 STN在深度学习中的应用可以为图像分类、物体检测和目标跟踪等模型提供最优化的输出。STN可以大幅提升网络的稳健性和大数据集的学习能力,尤其是在出现图像旋转、缩放和平移等情况时,STN的适应性更加强大,因为它能够自适应性地应对各种图像变形,这也是它能够在计算机视觉领域中具备很高的使用价值的原因。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

学弟

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值