【论文笔记】STANet:基于孪生神经网络的时空注意力变化检测模型

本文介绍了一种基于孪生神经网络的时空注意力变化检测模型STANet,该模型能够有效处理不同尺度的物体变化检测问题。通过引入自注意力机制和多尺度特征提取,STANet能够在不同时间和空间位置间建立联系,提高变化检测的准确性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

本文是论文《A Spatial-Temporal Attention-Based Method and a New Dataset for Remote Sensing Image Change Detection》的阅读笔记,由于原文比较长,本文有很多省略,着重介绍该模型是怎么运作的。

一、相关工作

文章针对遥感图像变化检测问题提出了一个基于孪生神经网络的时空注意力变化检测模型STANet,其中的自注意力模块可以计算任意两张拍摄于不同日期和位置的图像的注意力权重,并产生更具辨别性的特征。考虑到物体可能具有不同的大小,文章还将图像分割成了多尺度的子区域,并在每个子区域中引入了自注意力机制。此外还创建了新的变化检测数据集LEVIR-CD。

在这里插入图片描述

上图a是时空注意力的示意图,b是图像误配准的情况。

大多数基于机器学习的变化检测方法都包括两步:单元分析和变化识别。单元分析是分析单元的原数据的特征,分析单元可以分为图像像素和图像物体两大类。变化识别使用手工或学习到的规则来计算特征差图并使用阈值分割得到不同的变化区域。

基于深度学习的变化检测方法主要可以分为两类:基于度量的方法和基于分类的方法。基于度量的方法通过对比图像之间参数化的距离来决定是否发生变化。每一对点之间的特征的度量表示是否发生了变化。基于分类的方法通过对提取到的图像特征进行分类,从而识别变化的类别。STANet属于基于度量的方法。

二、方法和网络结构

1. motivation

文章的motivation如下:

  • 变化检测数据是有时间维度和空间维度的光谱向量组成的,开发不同时空位置之间的关系可以提升变化检测方法的效果。因此提出了时空自注意力机制。
  • 由于变换物体可能具有不同的大小,从一个合适的范围内提取特征可以更好地表示一定尺度的对象。可以通过从不同大小区域提取得到的特征结合起来以获得不同尺度的特征。因此将图像分割成了多尺度的子区域,并在每个子区域中引入了自注意力机制。

2. 网络结构

文章设计了两种自注意力模块,一是基本的时空注意力模块BAM,二是金字塔时空注意力模块PAM。BAM任意两个位置之间的时空独立性注意力权重,并通过时空中所有位置特征的加权和来计算每个位置的响应。PAM将BAM嵌入得到一个金字塔结构以产生多尺度的注意力表示。

在这里插入图片描述

上图是STANet的结构示意图,图中的 C × H × W C\times H\times W C×H×W C C C是通道数, H H H W W W是特征图的高和宽。

STANet包括特征提取器、注意力模块、度量模块三部分。首先两张图像被喂入到两个特征提取器中获得两个特征图 X ( 1 ) X^{(1)} X(1) X ( 2 ) X^{(2)} X(2),经过注意力模块的处理后得到两张注意力特征图 Z ( 1 ) Z^{(1)} Z(1) Z ( 2 ) Z^{(2)} Z(2),在将注意力特征图resize到输入图像大小之后,度量模块会计算两个注意力特征图的每个像素对之间的距离,并产生一个距离图 D D D,然后通过简单的阈值法得到最终的变化标签图 P P P

特征提取器

特征提取器中用到了ResNet-18,由于ResNet是用来进行图像分类任务而变化检测是密集分类任务,所以省略了ResNet中的全局池化层和全连接层。

BAM

在BAM中,特征图 X X X首先通过三个不同的 1 × 1 1\times1 1×1的卷积层得到三个特征向量 Q , K , V Q, K, V Q,K,V,分别表示查询、键和值。然后对其reshape得到矩阵 Q ˉ , K ˉ , V ˉ \bar Q,\bar K,\bar V Qˉ,Kˉ,Vˉ,并使用转置后的 K ˉ \bar K Kˉ Q ˉ \bar Q Qˉ进行矩阵乘法并使用softmax计算一个相似矩阵 A A A,该相似矩阵与 V ˉ \bar V Vˉ进行矩阵乘法得到输出矩阵 Y ˉ \bar Y Yˉ,对其进行reshape得到注意力 Y Y Y Y Y Y X X X进行像素级乘法得到最终的注意力特征图 Z Z Z

PAM

而PAM有4个分支,每个分支将特征图 X X X分成了不同大小的子区域,并在每个子区域中应用BAM,每个分支的输出拼接起来和输入大小相同,将4个分支的输出concate起来并用 1 × 1 1\times1 1×1的卷积层进行处理得到注意力 Y Y Y Y Y Y X X X进行像素级乘法得到最终的注意力特征图 Z Z Z

度量模块

度量模块首先将特征图使用双线性插值resize到和输入相同的大小,然后计算两个特征图之间像素级的欧氏距离图 D D D,在训练阶段,用其来计算损失值,在测试阶段使用一个固定的阈值方法进行分割。

3. 损失函数

文章设计了一个批量平衡对比损失(BCL),利用批次权重对原始对比损失的类权重进行修正,其定义如下:
L ( D ∗ , M ∗ ) = 1 2 1 n u ∑ b , i , j ( 1 − M b , i , j ∗ ) D b , i , j ∗ + 1 2 1 n c ∑ b , i , j M b , i , j ∗ Max ⁡ ( 0 , m − D b , i , j ∗ ) \begin{aligned}L\left(D^{*}, M^{*}\right) &=\frac{1}{2} \frac{1}{n_{u}} \sum_{b, i, j}\left(1-M_{b, i, j}^{*}\right) D_{b, i, j}^{*} \\&+\frac{1}{2} \frac{1}{n_{c}} \sum_{b, i, j} M_{b, i, j}^{*} \operatorname{Max}\left(0, m-D_{b, i, j}^{*}\right)\end{aligned} L(D,M)=21nu1b,i,j(1Mb,i,j)Db,i,j+21nc1b,i,jMb,i,jMax(0,mDb,i,j)
其中, M ∗ M^* M是二值标签图的一个批次, b , i , j b,i,j b,i,j表示批次的下标、高度、宽度。 m m m是margin, n u , n C n_u,n_C nu,nC是未变化和变化了的像素对的个数,其计算公式如下:
n u = ∑ b , i , j 1 − M b , i , j ∗ n c = ∑ b , i , j M b , i , j ∗ \begin{array}{c}n_{u}=\sum_{b, i, j} 1-M_{b, i, j}^{*} \\n_{c}=\sum_{b, i, j} M_{b, i, j}^{*}\end{array} nu=b,i,j1Mb,i,jnc=b,i,jMb,i,j

3. LEVIR-CD数据集

在这里插入图片描述

上图是生成的LEVIR-CD数据集的样例。

LEVIR-CD数据集的总体情况。

与其他数据集的对比

三、实验

实验使用的数据集有SZTAKI AirChange Benchmark Set (SZTAKI)、The Onera Satellite Change Detection dataset (OSCD)、The Aerial Imagery Change Detection dataset (AICD)以及LEVIR-CD数据集。

使用的评价指标是每一类的准确率、召回率和F1值。使用的baseline是FCN-Network、FCN-Network+BAM、FCN-Network+PAM。

上图是在LEVIR-CD数据集上的消融实验结果表。

在这里插入图片描述

在LEVIR-CD数据集上的结果。

在这里插入图片描述

上图是BCL损失函数在LEVIR-CD数据集上的消融实验结果表。

在这里插入图片描述
在这里插入图片描述

以上两图是不同方法在SZTAKI数据集上的结果对比。

在这里插入图片描述

注意力图的可视化结果。

### Pyramid Attention Module 实现与应用 #### 背景介绍 在深度学习领域,注意力机制被广泛应用于提升模型性能。金字塔注意力模块(PAM, Pyramid Attention Module)是一种特殊的注意力机制,在处理图像数据时表现出色。PAM通过多尺度特征融合来增强网络的感受野并提高特征表达能力。 #### PAM 的工作原理 PAM 主要由两个部分组成:空间维度上的全局上下文建模和通道维度的选择性加权[^1]。具体来说: - **空间金字塔池化层(SPP)**: 将输入特征图划分为不同大小的空间区域,并计算这些区域内像素的最大值或平均值作为该区域的表示; - **自适应推理模块(ARM)**: 对SPP得到的不同层次特征进行线性变换后相加以获得最终输出;此过程可以看作是在多个尺度上动态调整权重的过程。 这种设计使得PAM能够捕捉到更丰富的语义信息以及局部细节,从而有助于改善视觉识别任务的效果。 #### PyTorch 中实现 PAM 下面给出一段基于PyTorch框架下简单的PAM实现代码示例: ```python import torch.nn as nn class SpatialAttention(nn.Module): def __init__(self, kernel_size=7): super().__init__() self.conv = nn.Conv2d( 2, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2) def forward(self, x): avg_out = torch.mean(x, dim=1, keepdim=True) max_out, _ = torch.max(x, dim=1, keepdim=True) scale = torch.cat([avg_out, max_out], dim=1) scale = self.conv(scale) return x * torch.sigmoid(scale) class ChannelAttention(nn.Module): def __init__(self, channels, reduction_ratio=16): super().__init__() hidden_channels = int(channels / reduction_ratio) self.fc1 = nn.Linear(channels, hidden_channels, bias=False) self.relu = nn.ReLU(inplace=True) self.fc2 = nn.Linear(hidden_channels, channels, bias=False) def forward(self, x): b, c, _, _ = x.size() out_avg = F.adaptive_avg_pool2d(x, output_size=(1, 1)).view(b, c) out_max = F.adaptive_max_pool2d(x, output_size=(1, 1)).view(b, c) out_avg = self.fc2(self.relu(self.fc1(out_avg))) out_max = self.fc2(self.relu(self.fc1(out_max))) scale = torch.sigmoid(out_avg + out_max).unsqueeze(-1).unsqueeze(-1) return x * scale class PyramidAttentionModule(nn.Module): def __init__(self, in_channels, pyramid_sizes=[1, 2, 4]): super(PyramidAttentionModule, self).__init__() # 定义空间注意力建模组件 self.spatial_attention = SpatialAttention() # 构造多级联接结构用于提取多尺度特征 self.pyramids = nn.ModuleList([ nn.Sequential( nn.AdaptiveAvgPool2d(output_size=size), nn.Conv2d(in_channels=in_channels, out_channels=in_channels//len(pyramid_sizes), kernel_size=1), nn.BatchNorm2d(num_features=in_channels//len(pyramid_sizes)), nn.ReLU(), nn.Upsample(size=None, scale_factor=size, mode='bilinear', align_corners=True)) for size in pyramid_sizes]) # 维度转换操作以便后续连接 self.fusion_conv = nn.Conv2d(in_channels*2, in_channels, kernel_size=1) # 定义通道选择性加权组件 self.channel_attention = ChannelAttention(in_channels) def forward(self, input_tensor): spatial_attended_feature = self.spatial_attention(input_tensor) multi_scale_features = [] for layer in self.pyramids: feature_map = layer(spatial_attended_feature) multi_scale_features.append(feature_map) concatenated_feature_maps = torch.cat(multi_scale_features+[spatial_attended_feature], dim=1) fused_output = self.fusion_conv(concatenated_feature_maps) channel_weighted_fused_output = self.channel_attention(fused_output) return channel_weighted_fused_output ``` 上述代码实现了完整的PAM功能,包括了空间注意力机制、多尺度特征抽取以及通道间的信息交互。其中`SpatialAttention`负责构建空间域内的依赖关系,而`ChannelAttention`则专注于优化各个通道的重要性程度。最后两者的结果会被结合起来形成更加鲁棒性的表征向量。
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值