paper:Dual Attention Network for Scene Segmentation
official implementation:https://github.com/junfu1115/DANet
third-party implementation:https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/decode_heads/da_head.py
背景
尽管基于全卷积网络(FCNs)的最新方法已经取得进展,但现有方法通常通过多尺度特征融合来捕获上下文信息,并不能充分利用全局视角中对象或材质之间的关系。此外,一些方法使用循环神经网络来捕获长距离依赖关系,但其有效性在很大程度上依赖于长期记忆的学习结果。
本文的创新点
本文提出了一种新的语义分割框架,称为双注意力网络(DANet),用于自然场景图像分割。该框架旨在通过自注意力机制分别捕获空间和通道维度上的特征依赖性,以解决现有方法在处理复杂和多样化场景时的局限性。
具体包括提出了位置注意力模块和通道注意力模块,分别对空间和通道维度上的语义特征进行建模,以自适应地整合局部特征与全局依赖性。在注意力模块中,通过自注意力机制,模型能够在空间维度上捕捉任何两个位置之间的语义相互依赖性,在通道维度上强调相互依赖的通道映射。
方法介绍
DANet的整体结构如图2所示,整体结构就是一个FCN,其中作者删除了ResNet后两个stage中的降采样,并采用了dilated convolution,从而将特征图扩大到输入的1/8。创新点就在于Position Attention Module和Channel Attention Module两个注意力模块,下面具体介绍。
Position Attention Module
如图3(A)所示,给定一个特征 \(\mathbf{A}\in \mathbb{R}^{C\times H\times W}\),首先通过一个卷积层得到两个新的特征图 \(\mathbf{B}\) 和 \(\mathbf{C}\),\(\{\mathbf{B},\mathbf{C}\}\in \mathbb{R}^{C\times H\times W}\)。然后将它们reshape成 \(\mathbb{R}^{C\times N}\),其中 \(N=H\times W\) 是像素的数量。然后将 \(\mathbf{B}\) 与 \(\mathbf{C}\) 的转置进行矩阵相乘,然后通过softmax来计算spatial attention map \(S\in\mathbb{R}^{N\times N}\)
其中 \(s_{ji}\) 描述的是 \(i^{th}\) 位置对 \(j^{th}\) 位置的影响,两者的特征表示越相似,它们之间的相关性就越大。
同时,特征 \(\mathbf{A}\) 通过另一个卷积层得到一个新的特征图 \(\mathbf{D}\in \mathbb{R}^{C\times H\times W}\) 并reshape成 \(\mathbb{R}^{C\times N}\)。然后将 \(\mathbf{D}\) 与 \(\mathbf{S}\) 的转置进行矩阵相乘。最后我们将其乘上一个尺度参数 \(\alpha\) 并与原始特征 \(\mathbf{A}\) 进行element-wise相加求和得到最终输出 \(\mathbf{E}\)
其中 \(\alpha\) 初始化为0并逐渐学习分配更多的权重。从式(2)可以看到,\(\mathbf{E}\) 中每个位置的结果是所有位置的特征和原始特征的加权和,所以它具有全局的上下文视图,并根据spatial attention map有选择地聚合上下文。相似的语义特征互相都得到了加强,从而增强了类内的紧凑性和语义的一致性。
Channel Attention Module
通道注意力模块的结构如图3(B)所示。与位置注意力模块不同,我们直接通过原始特征 \(\mathbf{A}\in \mathbb{R}^{C\times H\times W}\) 计算通道attention map \(\mathbf{X}\in \mathbb{R}^{C\times C}\)。具体来说,我们将 \(\mathbf{A}\) reshape成 \(\mathbb{R}^{C\times N}\),然后对 \(\mathbf{A}\) 和 \(\mathbf{A}\) 的转置进行矩阵相乘,最终通过softmax得到通道attention map \(\mathbf{X}\in \mathbb{R}^{C\times C}\)
其中 \(x_{ji}\) 描述了第 \(i^{th}\) 通道对第 \(j^{th}\) 通道的影响。此外,我们对 \(\mathbf{X}\) 的转置和 \(\mathbf{A}\) 进行矩阵相乘并将结果reshape成 \(\mathbb{R}^{C\times H\times W}\),然后乘以一个缩放系数 \(\beta\) 并和 \(\mathbf{A}\) 进行一个element-wise相加得到最终结果 \(\mathbf{E}\in \mathbb{R}^{C\times H\times W}\)
其中 \(\beta\) 从0逐渐学习一个权重。公式(4)表明每个通道的最终特征是所有通道的特征和原始特征的加权和,它建模了特征图之间的long-range semantic dependecies,有助于提高特征的辨别性。
代码解析
面的代码是MMSegmentation中的实现,首先是PAM,position attention module调用SelfAttentionBlock和OCRNet中调用的是同一个,关于OCRNet的介绍见OCRNet原理与代码解析(ECCV 2020)-CSDN博客,这里不再过多解释。
class PAM(_SelfAttentionBlock):
"""Position Attention Module (PAM)
Args:
in_channels (int): Input channels of key/query feature.
channels (int): Output channels of key/query transform.
"""
def __init__(self, in_channels, channels):
super().__init__(
key_in_channels=in_channels,
query_in_channels=in_channels,
channels=channels,
out_channels=in_channels,
share_key_query=False,
query_downsample=None,
key_downsample=None,
key_query_num_convs=1,
key_query_norm=False,
value_out_num_convs=1,
value_out_norm=False,
matmul_norm=False,
with_out=False,
conv_cfg=None,
norm_cfg=None,
act_cfg=None)
self.gamma = Scale(0)
def forward(self, x):
"""Forward function."""
out = super().forward(x, x)
out = self.gamma(out) + x
return out
然后是,CAM,channel attention module,实现如下
class CAM(nn.Module):
"""Channel Attention Module (CAM)"""
def __init__(self):
super().__init__()
self.gamma = Scale(0)
def forward(self, x): # (8,512,60,60)
"""Forward function."""
batch_size, channels, height, width = x.size()
proj_query = x.view(batch_size, channels, -1) # (8,512,3600)
proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1) # (8,3600,512)
energy = torch.bmm(proj_query, proj_key) # (8,512,512), 每个通道hw的特征图和所有其它通道的特征图求点积
energy_new = torch.max(
energy, -1, keepdim=True)[0].expand_as(energy) - energy # (8,512,1)->(8,512,512) - (8,512,512), 文章中好像没有这一步,每个注意力都减去最大注意力值
attention = F.softmax(energy_new, dim=-1) # (8,512,512), 每个通道和所有其它通道注意力的和为1,值越大说明这两个通道之间的关联性越大
proj_value = x.view(batch_size, channels, -1) # (8,512,3600)
out = torch.bmm(attention, proj_value) # (8,512,3600)
# attention是512x512的每一行和为1,表示一个通道和其它所有通道的注意力(相关性)
# 以attention的第一行与proj_value的第一列求点积得到out第一个点[0,0]的值为例,表示proj_value中所有通道[0,0]点的值根据该通道与其它所有通道的注意力加权求和。
# 输出out中某个通道某个像素点的值等于proj_value中这个点所有通道的值根据这个通道与其它所有通道的注意力进行加权求和得到。
out = out.view(batch_size, channels, height, width) # (8,512,60,60)
out = self.gamma(out) + x
return out