Spatial Transformer Networks(空间转换器)及在MNIST中的应用


STN空间变换器在一些论文中会见到,而且因其简单有效、即插即用等特性,应用较多。为了充分理解论文和方便日后使用,这里记录一下STN以及应用在MNIST任务中。

1、Spatial Transformer Networks

空间转换器,空间变换器网络(简称STN)允许神经网络学习如何对输入图像进行空间变换,以增强模型的几何不变性。例如,它可以裁剪感兴趣的区域,缩放和校正图像的方向。它可能是一种有用的机制,因为 CNN 对于旋转和缩放以及更一般的仿射变换不是不变的。

在了解STN之前,你需要先学习仿射变换基础,可以看看图像仿射变换
总之仿射变换就是原图*转换矩阵=仿射后的图,根据转换矩阵的不同,可以实现比如图像平移、缩放、旋转、翻转等等,而STN可以简单理解为通过CNN来自动学习转换矩阵,使得原图和转换矩阵运算后,能够被掰正。
在这里插入图片描述
比如在MNIST分类中,插入STN模块的效果,STN首先会将图像掰正,然后再次进行分类。

2、结构

在这里插入图片描述
上图就是STN模块,首先从UV的右侧面看出,图像经过该模块被纠正(旋转)了。

其代码如下:

def stn(self, x):
    """
    该部分经过卷积和全连接层,从原图拟合出用于仿射变换的转换矩阵,其shape=(2,3)
    :param x: 原图,shape=(1,28,28)
    :return:  仿射变换(掰正)后的图,shape=(1,28,28)
    """
    xs = self.localization(x)
    xs = xs.view(-1, 10 * 3 * 3)
    theta = self.fc_loc(xs)
    theta = theta.view(-1, 2, 3)    # 转换矩阵

    grid = F.affine_grid(theta, x.size())
    x = F.grid_sample(x, grid)

    return x

2.1 Localization net

注意,在代码中Localization net是包括self.localizationself.fc_loc两部分。

该部分完成从输入U中提取特征,拟合出变换矩阵参数θ,具体结构为:
在这里插入图片描述

2.2 Grid generator

其实这部分就只有一行代码:

grid = F.affine_grid(theta, x.size())

通过theta和希望变换后的尺寸s.size()来产生grid。关于这部分其实就是一个API的使用,可以百度一下就可。

2.3 Sampler

这部分也只有一行代码:

x = F.grid_sample(x, grid)

通过的得到的grid来对原图x进行变换,最终得到变换后的图x。

2.2 和2.3 这一步其实opencv也提供了warpAffine这个API来进行仿射变换,但是因为模型训练时tensor最好在GPU上运算,所以直接使用pytorch提供的affine_grid和grid_sample来进行仿射变换。这两个已经不属于深度学习的范畴了,就是稍微复杂一点的普通运算,百度一下即可。

3、实验

实验部分很简单,就是MNIST数字分类。只不过多加了STN模块。

实验 整体结构如下图所示,上面的STN负责将U变换掰正为V,下面的MNIST分类就是简单的CNN网络,这在很多关于MNIST入门教程中都能见到。

或者说就是在普通MNIST分类的CNN网络中,插入了STN模块,实现在使用CNN分类前先矫正MNIST图像数字的功能。
在这里插入图片描述

3.1 STN

STN部分将输入的图片U掰正并输出图片V,图片U和图片V的尺寸是一样的,都是 1 ∗ 28 ∗ 28 1*28*28 12828,下图展示实验训练100epochs后,U(Dataset Images)和V(Transformed Images),可以看到此时STN已经起作用了。图中一些歪七倒八的数字已经被”掰正“了。
在这里插入图片描述

3.2 MNIST分类

底下这部分就是卷积层+全连接层的简单分类网络,这不用多说了,MNIST就是深度学习的“Hello World!”。

3.3 效果展示

Spatial Transformer Networks对MNIST的“纠正”效果(epoch=20时的效果)
在这里插入图片描述
Test Acc曲线(增加epoch=100,能达到99%+)
在这里插入图片描述

4、代码

整个项目放在thgpddl/SpatialTransformerNetworks这里。

5、思考

在本实验中,MNIST中的一些样本通过STN被“掰正”,但是我们知道仿射变换还可以实现比如旋转、平移、缩放、裁剪等效果,那么是否在某些任务中,STN的作用可能时其他仿射效果呢?比如某任务中一些样本图像中目标很小,是否插入STN后的效果是将图像缩放至目标很大呢,这样不就相当于“注意力”了吗?

比如在论文阅读:Spatial Transformer Networks中,就有两个并行的STN,从效果来看,确实有放大的作用(从视觉上来看也可以是平移裁剪等实现),不就相当于一种注意力码?

6、引用

SPATIAL TRANSFORMER NETWORKS TUTORIAL

  • 3
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
空间变换网络(Spatial Transformer Networks,STN)是一种神经网络结构,用于改善卷积神经网络(CNN)的空间不变性。STN可以对经过平移、旋转、缩放和裁剪等操作的图像进行变换,使得网络在变换后的图像上得到与原始图像相同的检测结果,从而提高分类的准确性。STN由三个主要部分组成:局部化网络(Localisation Network)、参数化采样网格(Parameterised Sampling Grid)和可微分图像采样(Differentiable Image Sampling)。 局部化网络是STN的关键组件,它负责从输入图像学习如何进行变换。局部化网络通常由卷积和全连接层组成,用于估计变换参数。参数化采样网格是一个由坐标映射函数生成的二维网格,它用于定义变换后每个像素在原始图像的位置。可微分图像采样则是通过应用参数化采样网格来执行图像的变换,并在变换后的图像上进行采样。 使用STN的主要优点是它能够在不改变网络结构的情况下增加空间不变性。这使得网络能够处理更广泛的变换,包括平移、旋转、缩放和裁剪等。通过引入STN层,CNN可以学习到更鲁棒的特征表示,从而提高分类准确性。 关于STN的代码实现,您可以在GitHub上找到一个示例实现。这个实现使用TensorFlow框架,提供了STN网络的完整代码和示例。您可以通过查看该代码来了解如何在您的项目使用STN。 综上所述,spatial transformer networks空间变换网络)是一种神经网络结构,用于增加CNN的空间不变性。它包括局部化网络、参数化采样网格和可微分图像采样三个部分。通过引入STN层,CNN可以学习到更鲁棒的特征表示,从而提高分类准确性。在GitHub上有一个使用TensorFlow实现的STN示例代码供参考。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

我是一个对称矩阵

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值