Spatial Transformer Networks

论文:https://arxiv.org/abs/1506.02025

1 核心思想

CNN中使用的最大池化操作,使得网络对输入图像具有了一定的平移不变性。但是,由于一般使用的池化核很小(2 x 2),因此需要使用多个最大池化层才能实现较大的平移不变性。但即便如此,网络的中间层对输入的平移不变性仍然比较小。

作者提出了一个空间变换层(Spatial Transformer Layer,STL)实现对输入的平移、缩放、旋转、裁剪等操作。STL是可微分的,可以加入到CNN中进行前向和反向传播,实现对输入feature map的变换。对于多通道的输入,是对每一个channel应用相同的变换。

STL包含三部分,分别是定位网络(Localisation Network)、grid generator和sampler。

在这里插入图片描述

1.1 定位网络

定位网络的目的是学习对输入变换的参数 θ \theta θ。其输入为 U ∈ R H × W × C U \in R^{H \times W \times C} URH×W×C。定位网络可以是一个小型的全连接网络,也可以是一个小型的卷积网络,但最后应该是一个回归层用于输出预测参数 θ \theta θ

要学习的参数的数量和定义的变换的形式有关,例如,对于仿射变换,要学习的参数量即为6个。

以仿射变换为例,偏移、缩放、旋转都是对输入feature map应用线性变换。

平移:
[ x ′ y ′ 1 ] = [ 1 0 Δ x 0 1 Δ y ] [ x y 1 ] \left[\begin{matrix}x^{'} \\ y^{'} \\ 1 \end{matrix}\right] = \left[\begin{matrix}1 &&0 && \Delta x\\ 0 && 1 && \Delta y\end{matrix}\right] \left[\begin{matrix}x \\ y \\ 1 \end{matrix}\right] xy1=[1001ΔxΔy]xy1

缩放:
[ x ′ y ′ 1 ] = [ s 1 0 0 0 s 2 0 ] [ x y 1 ] \left[\begin{matrix}x^{'} \\ y^{'} \\ 1 \end{matrix}\right] = \left[\begin{matrix}s_1 &&0 && 0\\ 0 && s_2 && 0\end{matrix}\right] \left[\begin{matrix}x \\ y \\ 1 \end{matrix}\right] xy1=[s100s200]xy1

旋转:
[ x ′ y ′ 1 ] = [ cos ⁡ θ − sin ⁡ θ 0 sin ⁡ θ cos ⁡ θ 0 ] [ x y 1 ] \left[\begin{matrix}x^{'} \\ y^{'} \\ 1 \end{matrix}\right] = \left[\begin{matrix}\cos\theta &&-\sin\theta && 0\\ \sin\theta && \cos\theta && 0\end{matrix}\right] \left[\begin{matrix}x \\ y \\ 1 \end{matrix}\right] xy1=[cosθsinθsinθcosθ00]xy1

1.2 grid generator

grid generator是确定输出feature map每一个点映射到输入feature map的哪个点。具体的数学变换形式为:
在这里插入图片描述这里, ( x i t , y i t ) (x^t_i,y^t_i) (xit,yit)表示输出feature map的某个位置, ( x i s , y i s ) (x^s_i,y^s_i) (xis,yis)表示输入feature map的某个位置, A θ A_{\theta} Aθ就是变换矩阵。

这里为什么要将变换矩阵应用于输出feature map而不是输入feature map?一个相对合理的解释是,输出要从输入上拿数据点,而目标是需要填满输出feature map,需要遍历输出feature map且保证其为规则的矩形。如果是对输入进行变换,那么可能造成输出是非规则的,即便我们可以使用其最小外接矩形保证其形状规则,但新添加的点的值无法确定。所以这里用输出到输入的变换,保证输出是规则的且每一个点的值是确定的。

不同形式的变换矩阵,可以实现不同的变换,如下图(a)实现恒等变换,(b)实现仿射变换。
在这里插入图片描述应用于图像注意力区域学习时,使用的变换矩阵可以是如下形式:
在这里插入图片描述
即输入是输出的缩放和平移,缩放又是对各个坐标轴同等尺度的缩放。

1.3 可微分的采样

上面求得了变换矩阵,下面就需要对输入feature map U进行采样得到输出feature map V。采样的函数可以表示成:
在这里插入图片描述 k ( ) k() k()是采样函数, Φ x , Φ y \Phi_x,\Phi_y Φx,Φy是采样函数的参数, U n m c U_{nm}^c Unmc是输入的第c个channel位置(n,m)处的值, V i c V_i^c Vic是输出的第c个channel位置 ( x i t , y i t ) (x_i^t,y_i^t) (xit,yit)处的值。采样时也是对所有的channel应用相同的操作。

理论上,这里可以使用任意的采样函数,只要其对 x i s , y i s x_i^s,y_i^s xis,yis是可微的即可。以双线性差值为例,插值公式可以写作:
在这里插入图片描述那么反向传播公式有:
在这里插入图片描述公式(6)可以看出,输出feature map对输入是可微分的,输出的feature map对采样的坐标也是可微分的,由于 ∂ x i s ∂ θ \frac{\partial x_i^s}{\partial \theta} θxis ∂ y i s ∂ θ \frac{\partial y_i^s}{\partial \theta} θyis也可以求得,那么就可以通过输出对变换参数 θ \theta θ和定位网络求梯度进行训练。

2 应用

将spatial transformer layer加入到网络中,即可实现spatial transformer network。可以对一个输入应用多个spatial transformer layer,让每个STL去学习输入中的一个感兴趣的目标,将各变换的输出进行融合即可实现对多个目标的变换。

使用STL时不需要添加额外的标注信息,使用原来的损失函数即可以让模型学习对输入的变换以取得更好的处理结果。

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述应用于分类的时候,可以使用STL对输入进行变换,规范化其形状之后有助于分类。应用于细粒度分类时,可以使用STL作为一种注意力机制,使用多个STL去挖掘图像中多个感兴趣的目标区域。

3 pytorch实现代码

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # Initialize the weights/bias with identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    # Spatial transformer network forward function
    def stn(self, x):
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)

        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)

        return x

    def forward(self, x):
        # transform the input
        x = self.stn(x)

        # Perform the usual forward pass
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

三个关键点:

  1. 定位网络最后一层是Linear层;
  2. torch.nn.functional.affine_grid(theta,x.size())得到变换后的target的坐标;
  3. torch.nn.functional.grid_sample(x,grid)得到插值的结果。

参考:
https://blog.csdn.net/qq_39422642/article/details/78870629
https://www.jianshu.com/p/723af68beb2e
https://blog.csdn.net/xholes/article/details/80457210

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
空间变换网络(Spatial Transformer Networks,STN)是一种神经网络结构,用于改善卷积神经网络(CNN)的空间不变性。STN可以对经过平移、旋转、缩放和裁剪等操作的图像进行变换,使得网络在变换后的图像上得到与原始图像相同的检测结果,从而提高分类的准确性。STN由三个主要部分组成:局部化网络(Localisation Network)、参数化采样网格(Parameterised Sampling Grid)和可微分图像采样(Differentiable Image Sampling)。 局部化网络是STN的关键组件,它负责从输入图像中学习如何进行变换。局部化网络通常由卷积和全连接层组成,用于估计变换参数。参数化采样网格是一个由坐标映射函数生成的二维网格,它用于定义变换后每个像素在原始图像中的位置。可微分图像采样则是通过应用参数化采样网格来执行图像的变换,并在变换后的图像上进行采样。 使用STN的主要优点是它能够在不改变网络结构的情况下增加空间不变性。这使得网络能够处理更广泛的变换,包括平移、旋转、缩放和裁剪等。通过引入STN层,CNN可以学习到更鲁棒的特征表示,从而提高分类准确性。 关于STN的代码实现,您可以在GitHub上找到一个示例实现。这个实现使用TensorFlow框架,提供了STN网络的完整代码和示例。您可以通过查看该代码来了解如何在您的项目中使用STN。 综上所述,spatial transformer networks(空间变换网络)是一种神经网络结构,用于增加CNN的空间不变性。它包括局部化网络、参数化采样网格和可微分图像采样三个部分。通过引入STN层,CNN可以学习到更鲁棒的特征表示,从而提高分类准确性。在GitHub上有一个使用TensorFlow实现的STN示例代码供参考。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值