目录
前沿
空间变换网络(Spatial Transformer Networks,简称STN)是一种深度学习模型,旨在增强网络对几何变换的适应能力。STN是由Max Jaderberg等人在2015年提出的,其核心思想是在传统的卷积神经网络(CNN)中嵌入一个可学习的模块,该模块能够显式地对输入图像进行空间变换,从而使得网络能够对输入图像的几何变形具有更好的适应性。STN的引入使得网络能够自动进行图像的校正,例如旋转、缩放、剪切等,这在很多视觉任务中是非常有用的,如图像识别、目标检测和图像分割等。
STN可以使模型学习平移、缩放、旋转和更通用的扭曲的不变性。(二维空间变换网络)
CNN分类时,通常需要考虑输入样本的局部性、平移不变性、缩小不变性,旋转不变性等,以提高分类的准确度。这些不变性的本质就是图像处理的经典方法,即图像的裁剪、平移、缩放、旋转,而这些方法实际上就是对图像进行空间坐标变换,我们所熟悉的一种空间变换就是仿射变换,图像的仿射变换公式可以表示如下:
STN相当于这种显示的深度学习模块,用于学习图像的平移不变性、旋转不变性等信息。
STN基础架构
STN可以通过为每个输入样本生成适当的变换来主动对图像(或特征图)进行空间变换。然后在整个特征图上(非局部)执行变换,并且可以包括缩放、裁剪、旋转以及非刚性变形。这使得包含空间变换器的网络不仅可以选择图像中最相关(注意力)的区域,还可以将这些区域转换为规范的预期姿势,以简化后续层中的推理。
STN的工作流程可以分为三个主要部分:
-
定位网络(Localization Network):这一部分是STN的核心,其任务是学习输入图像的空间变换参数。定位网络可以是任意的网络结构,它接受输入图像,并输出空间变换所需的参数。这些参数定义了一个变换矩阵,用于调整图像的空间位置。
-
网格生成器(Grid Generator):接收定位网络输出的变换参数,并生成一个对应于输出图像的坐标网格。这个坐标网格对应于输入图像中的每一个像素位置。
-
采样器(Sampler):根据网格生成器的输出坐标网格,从输入图像中采样像素来产生变换后的输出图像。这一步骤确保了图像的空间变换是可微分的,从而可以通过反向传播算法进行训练。
STN目标是通过操作数据而不是特征提取器来实现不变表示。
流程:输入特征图U首先传递到定位网络,回归转化得到参数𝜃(得到用于放射变换的参数𝜃)。然后由网格生成器通过𝜃和定义的变换方式寻找输出和输入特征的映射关系𝑇𝜃𝐺。采样器结合位置映射和变换参数对输入特征进行选择并结合双线性插值得到输出。
各层详解
Localisation net
- 功能:定位网络的主要任务是预测空间变换的参数。根据输入图像,这个网络会输出一组参数,这些参数定义了一个空间变换,可以是平移、旋转、缩放等或者更复杂的仿射变换或者非线性变换。
- 结构:定位网络通常是一个小型的卷积神经网络或全连接网络,其具体结构可以根据任务的复杂度和输入数据的特性来定制。网络的输出大小是固定的,对应于特定变换所需的参数数量。
Grid generator
- 功能:网格生成器接收定位网络预测的变换参数,并生成一个坐标网格,该网格代表了输入图像中每个像素映射到输出图像中的新位置。
- 原理:对于每个输出图像的像素位置,网格生成器使用变换参数来计算对应的输入图像中的坐标。这一过程通常涉及到矩阵运算,用于实现平移、旋转、缩放等仿射变换。
负责根据定位网络输出的变换参数创建一个采样网格,该网格决定了如何从输入特征图(U)中采样以产生变换后的输出特征图(V)
网格生成器的工作基于一个简单的概念:为输出特征图(V)上的每个像素位置创建一个对应的输入特征图(U)中的采样位置。
步骤
Grid generator利用计算得到的𝜃,对Feature map进行相应的空间变换(即放射变换)。
当输入通过定位网路(Localisation Network)和网格生成器(Grid Generator)后,对输入的特征图(Feature Map)进行了放射变换,得到变换后的采样网格。此时的采样网格指定了输入特征图中哪些位置(即像素点)应该被采样来生成输出特征图。由于变换后的采样点可能位于输入特征图的非整数坐标上,因此需要一种能够处理这种情况的采样方法,即采样器。
(为了训练神经网络,我们需要能够计算损失函数相对于网络参数的梯度,这一过程通过反向传播算法实现。这要求网络中每一步操作都必须是可微分的,以便梯度可以从网络输出传递回网络输入。当涉及到STN的采样器时,我们需要确保特征得分(即输出特征图中的像素值)对于特征位置(即输入特征图中的采样位置)的偏导数是可计算的。这样,就可以根据输出特征图相对于输入特征图中位置的改变,来调整定位网络的参数,实现网络的端到端训练。)
Sampler
- 功能:采样器根据网格生成器提供的坐标网格从输入图像中采样像素值,生成变换后的输出图像。
- 细节:采样过程需要处理非整数像素坐标的情况,这通常通过双线性插值或其他插值方法来实现,以便从周围的像素值中估算出新像素的值。
采样器的主要作用是利用网格生成器产生的变换后的采样网格,从输入特征图(U)中采样出对应的像素值,以生成变换后的输出特征图(V)。这一过程是通过一种称为图像采样的技术实现的,保证了空间变换的可微分,从而允许梯度通过变换流回网络,实现端到端的训练。
构造一种Position->Score的映射,该映射通过采样器来实现,例如通过双线性采样根据非整数坐标处的采样值,估计邻近整数坐标处的像素得分值(这种映射是可微的,因此可通过计算得分相对位置处的梯度)。
Feature map通过上述两个模块后,网格生成器输出的特征图的每个像素点都将对应Feature map中的某个像素。然而,feature score对于feature position的偏导数无法计算,因此需要构造一种position->score的映射,且该映射具有可导的性质,从而满足反向传播的条件。即每一个输出的位置i,都有:
采样器的工作原理
采样器根据变换后的采样网格𝑇𝜃𝐺中的位置,从输入特征图(U)中提取像素值,以产生输出特征图(V)。每一个位置在采样网格中指示了输入特征图(U)上的一个具体采样点。采样过程中,可能会遇到非整数坐标的采样点,因此需要一种插值方法来确定这些点的像素值。常见的插值方法包括最近邻插值、双线性插值等。
工作流程总结
- 输入图像:STN接收原始图像作为输入。
- 定位网络:计算出空间变换所需的参数。
- 网格生成器:根据这些参数创建一个新的坐标网格,指示如何将输入图像变换到输出空间。
- 采样器:利用这个网格从原始图像中采样,生成变换后的图像。
STN的妙处在于其灵活性和泛用性,它可以应用于各种视觉任务中,如图像分类、物体检测和图像分割等,特别是在处理图像几何变换方面表现出色。通过学习空间变换,STN能够让网络更加关注于图像的重要特征,从而提高整体的性能和鲁棒性。