空间变换网络Spatial Transformer Networks(STN)

本文深入解析空间变换网络(STN)的工作原理及其在不同场景中的应用方式,介绍如何利用STN增强模型对图像变形的适应能力。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

目录

.1简介

.2 空间变换网络原理详解

2.1 概述

2.2 Localisation net

 2.3 Grid generator实现像素点坐标的对应关系

 2.4 Sampler实现坐标求解的可微性

.3 空间变换网络的实际应用

3.1.空间变换网络作为网络的第一层

3.2.空间变换网络插入CNN的中间层

4. 代码分析 

references:


.1简介

STN是一个可以加在网络中间的模块,使得网络能够对图像变形有适用性

比如加入了这个模块训练出来的模型,就会对变形的物体有一定的识别能力

因为模型里包含的参数是对数据进行仿射变换

本文提出了一种叫做空间变换网络(Spatial Transform Networks, STN)的网络模型,该网络不需要关键点的标定,能够根据分类或者其它任务自适应地将数据进行空间变换和对齐(包括平移、缩放、旋转以及其它几何变换等)。在输入数据空间差异较大的情况下,这个网络可以加在现有的卷积网络中,提高分类的准确性。
比如:

 例如对于上图中输入手写字体,我们感兴趣的是黄色框中的包含数字的区域,那么在训练的过程中,学习到的空间变换网络会自动提取黄色框中的局部数据特征,并对框内的数据进行空间变换,得到输出output。

.2 空间变换网络原理详解

2.1 概述

这里写图片描述

第一部分为为”localization net””localization net”网络中的参数则为空间变换网络需要训练的参数;

第二部分就是空间变换即仿射变换。通过该局部网络产生仿射变换系数θ

2.2 Localisation net

如下图是完成的一个平移的功能,这其实就是Spatial Transformer Networks要做一个工作。

这里写图片描述

 

 2.3 Grid generator实现像素点坐标的对应关系

得到变换前后的坐标映射关系

 2.4 Sampler实现坐标求解的可微性

如下所示,计算一下输出的结果与他们的下标的距离,可得: 

然后做如下更改:

 数学公式论证

.3 空间变换网络的实际应用

以上讲解的是空间变换网络的理解,那么在实际应用中,我们该如何添加空间变换网络到我们自己的网络中呢?接下来重点讲解空间变换网络的应用。

3.1.空间变换网络作为网络的第一层

空间变换网络可以直接作为网络的第一层,即Localisation Net的输入为input,从而直接对输入进行仿射变换,对于Localisation Net的设计,可以根据输入input的大小设计Localisation Net为全连接层或卷积层.

例如对于手写字体,输入图片大小为40x40,即input=[batch_size,1600],那么我们可以设计Localisation Net包含两个全连接层,第一个全连接层w1=[1600,20],b1=[20],第二个全连接层w2=[20,6],b2=[6],则第二个全连接层的输出为[batch_size,6],即为仿射变换系数。

3.2.空间变换网络插入CNN的中间层

空间变换网络还可以添加在CNN的中间层,可以直接将空间变换网络插入conv或者max-pooling层的前面或者后面。此外,还可以在CNN的同一层插入多个空间变换网络,下面给出空间变换网络插入CNN的手写字体网络结构图:

 这里写图片描述

上图中第一个空间变换网络ST1作用于输入图像,直接对输入图像进行空间变换,第二、三个空间变换网络ST2a,ST2b作用于conv1,用于对第一层的卷积特征进行空间变换,而ST3用于对更深层的卷积特征进行空间变换。 

 由于空间变换网络能够自动提取局部区域特征,因此在网络的同一层插入父哦个空间变换网络可以提取多个局部区域特征,从而可以结合多个局部区域特征进行分类:

如下如的网络是实现两张输入的图片中的手写字体相加,在网络的第一层插入两层空间变换网络ST1,ST2,并将其直接作用语输入图像。图中第三列为空间变换结果,有图可知,网络ST1,ST2分别提取了输入手写字体的不同区域的特征

这里写图片描述

4. 代码分析 

首先看一仿射变换的代码实现,代码的实现如上所述,首先由函数_meshgrid生成输出V的坐标位置点grid,在通过仿射变换系数theta对grid进行仿射变换得到U中对于位置坐标点T_g,之后对T_g进行双线性插值,并复制插值后的U中的坐标点的像素值到V中,得到输出V。


    def transform(theta, input_dim, out_size):
        with tf.variable_scope('_transform'):
            num_batch = tf.shape(input_dim)[0]
            height = tf.shape(input_dim)[1]
            width = tf.shape(input_dim)[2]
            num_channels = tf.shape(input_dim)[3]
            theta = tf.reshape(theta, (-1, 2, 3))
            theta = tf.cast(theta, 'float32')

            # grid of (x_t, y_t, 1), eq (1) in ref [1]
            height_f = tf.cast(height, 'float32')
            width_f = tf.cast(width, 'float32')
            out_height = out_size[0]
            out_width = out_size[1]
            grid = _meshgrid(out_height, out_width)
            grid = tf.expand_dims(grid, 0)
            grid = tf.reshape(grid, [-1])
            grid = tf.tile(grid, tf.pack([num_batch]))
            grid = tf.reshape(grid, tf.pack([num_batch, 3, -1]))#得到输出坐标位置点

            # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
            T_g = tf.batch_matmul(theta, grid)#仿射变换
            x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])#
            y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
            x_s_flat = tf.reshape(x_s, [-1])
            y_s_flat = tf.reshape(y_s, [-1])

            input_transformed = _interpolate(
                input_dim, x_s_flat, y_s_flat,
                out_size)#插值,并得到输出

            output = tf.reshape(
                input_transformed, tf.pack([num_batch, out_height, out_width, num_channels]))
            return output

完整代码如下:

https://github.com/tensorflow/models/blob/master/transformer/cluttered_mnist.py

references:

详细解读Spatial Transformer Networks(STN)-一篇文章让你完全理解STN了_黄小猿的博客-CSDN博客_stn

<think>好的,我需要回答用户关于Spatial Transformer Networks的中文名称的问题。首先,我应该查阅提供的引用资料,看看是否有直接提到中文翻译的地方。在引用[2]中,提到了“组合STNs (c-STNs)”和“空间变压器模块”,这里“STNs”应该就是Spatial Transformer Networks的缩写,所以中文名可能是“空间变压器网络”。另外,用户提供的引用[3]中也提到了“空间变压器模块”,进一步确认了“空间变压器”这个翻译。再检查其他引用,没有发现其他可能的翻译。因此,可以确定中文名称是“空间变压器网络”或者“空间变换网络”,但根据引用中的用词,“空间变压器网络”更准确。同时,可能需要确认常见的学术翻译习惯,比如“Transformer”通常译为“变压器”或“变换器”,但在神经网络领域,可能更倾向于“变换器”。不过引用[2]直接使用了“空间变压器模块”,所以应遵循引用中的翻译。最后,生成相关问题时要围绕STN的应用、原理等展开。</think>### Spatial Transformer Networks的中文名称 Spatial Transformer NetworksSTNs)的中文译名为**空间变换网络**或**空间变压器网络**。这一名称直接体现了其核心功能:通过学习几何变换参数,对输入数据进行空间层面的自适应调整,从而提高模型对形变、旋转、缩放等空间变化的鲁棒性[^2]。在文献中,“Transformer”一词常被译为“变换器”或“变压器”,而“Spatial”则对应“空间的”含义。例如,引用[2]中提到的“空间变压器模块”即指STN的核心组件。 ### 关键功能与原理 STN通过以下步骤实现空间变换: 1. **定位网络(Localisation Network)**:通过卷积层或全连接层预测几何变换参数(如仿射变换参数)。 2. **网格生成(Grid Generator)**:根据参数生成采样网格,映射输入到输出坐标。 3. **可微采样(Differentiable Sampling)**:使用双线性插值等可微操作生成变换后的特征图。 这种显式的几何变换模块可与传统卷积网络结合,增强模型的空间不变性[^3]。 ### 示例代码片段 ```python # 简化的STN实现(基于PyTorch) import torch import torch.nn as nn class SpatialTransformer(nn.Module): def __init__(self): super().__init__() self.localization = nn.Sequential( nn.Conv2d(1, 8, kernel_size=7), nn.MaxPool2d(2, stride=2), nn.ReLU(), nn.Conv2d(8, 10, kernel_size=5), nn.MaxPool2d(2, stride=2), nn.ReLU() ) self.fc = nn.Sequential( nn.Linear(10 * 3 * 3, 32), nn.ReLU(), nn.Linear(32, 6) # 输出仿射变换参数(2x3矩阵) ) def forward(self, x): theta = self.fc(self.localization(x).view(x.size(0), -1)) theta = theta.view(-1, 2, 3) grid = nn.functional.affine_grid(theta, x.size()) return nn.functional.grid_sample(x, grid) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值