paper:SCSA: Exploring the Synergistic Effects Between Spatial and Channel Attention
1、Spatial and Channel Synergistic Attention
通道和空间注意力分别为各种下游视觉任务的特征依赖性和空间结构关系提取带来了显著的改进。虽然两者的组合更有利于发挥各自的优势,但通道和空间注意力之间的协同作用尚未得到充分探索,还缺乏充分利用多语义信息的协同潜力,所以,论文提出一种新的 空间通道协同注意力模块(SCSA)。SCSA由两部分组成:**Shared Multi-Semantic Spatial Attention(SMSA )**和 Progressive Channel-wise Self-Attention(PCSA )。其中,SMSA 用于集成多语义信息,并利用渐进压缩策略将区分性空间先验注入到PCSA的通道自注意力中,有效将通道重新校准。而 PCSA 中基于自注意力机制的鲁棒特征交互则进一步缓解了SMSA中不同子特征之间的多语义信息差异。具体来说:
- SMSA: 采用多尺度 depth-wise 1D convolutions,分别从四个独立的子特征中提取不同语义层次的空间信息。并利用GroupNom来加速模型收敛,同时避免引入批量噪声以及子特征之间的语义信息泄漏的问题。
- **PCSA:**通过结合渐进压缩和通道特定的自注意力机制(CSA),最大限度地减少计算复杂性,同时保留 SMSA 内的空间先验。此外,PCSA 利用自我注意力机制,进一步探索了通道层面的相似性,从而减少了不同子特征之间的语义差异。
SCSA 旨在相互补充。空间注意力从每个特征中提取多语义空间信息,为通道注意力计算提供精确的空间先验;而通道注意力通过利用整体特征图来细化局部子特征的语义理解,减轻SMSA中多尺度卷积引起的语义差异。
对于一个给定的输入X,SCCA 的实现过程:
SMSA:
-
首先沿着H,W维度进行分解,并分别进行全局平均池化,从而建立两个单向的一维序列结构。
-
然后为了学习不同的空间分布和上下文关系,将特征集划分为K个相同大小的独立子特征。并经由多尺度 depth-wise 1D convolutions 处理。通过多尺度机制来更高效地捕获每个子特征内的不同语义空间结构。
-
为了解决 特征分解 与 一维卷积 导致的有限感受域,在depth-wise 1D convolutions之后,使用了轻量级共享卷积进行特征对齐。
-
最后,聚合不同的语义子特征并使用 GroupNorm 进行归一化,然后使用 Sigmoid 激活函数并与X相乘生成空间注意力。
PCSA:
-
由 SMSA 处理后的特征 X_s 先经由平均池化,然后进行 GroupNorm 归一化处理。
-
通过使用多分支的DWConv生成 Q,K,V。并通过自注意力聚合。
-
最后通过一个平均池化层和sigmoid激活函数,并与 X_s 进行相乘生成注意力图。
SCSA 结构图:
2、代码实现
import typing as t
import torch
import torch.nn as nn
from einops import rearrange
from mmengine.model import BaseModule
__all__ = ['SCSA']
class SCSA(BaseModule):
def __init__(
self,
dim: int,
head_num: int,
window_size: int = 7,
group_kernel_sizes: t.List[int] = [3, 5, 7, 9],
qkv_bias: bool = False,
fuse_bn: bool = False,
norm_cfg: t.Dict = dict(type='BN'),
act_cfg: t.Dict = dict(type='ReLU'),
down_sample_mode: str = 'avg_pool',
attn_drop_ratio: float = 0.,
gate_layer: str = 'sigmoid',
):
super(SCSA, self).__init__()
self.dim = dim
self.head_num = head_num
self.head_dim = dim // head_num
self.scaler = self.head_dim ** -0.5
self.group_kernel_sizes = group_kernel_sizes
self.window_size = window_size
self.qkv_bias = qkv_bias
self.fuse_bn = fuse_bn
self.down_sample_mode = down_sample_mode
assert self.dim // 4, 'The dimension of input feature should be divisible by 4.'
self.group_chans = group_chans = self.dim // 4
self.local_dwc = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[0],
padding=group_kernel_sizes[0] // 2, groups=group_chans)
self.global_dwc_s = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[1],
padding=group_kernel_sizes[1] // 2, groups=group_chans)
self.global_dwc_m = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[2],
padding=group_kernel_sizes[2] // 2, groups=group_chans)
self.global_dwc_l = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[3],
padding=group_kernel_sizes[3] // 2, groups=group_chans)
self.sa_gate = nn.Softmax(dim=2) if gate_layer == 'softmax' else nn.Sigmoid()
self.norm_h = nn.GroupNorm(4, dim)
self.norm_w = nn.GroupNorm(4, dim)
self.conv_d = nn.Identity()
self.norm = nn.GroupNorm(1, dim)
self.q = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
self.k = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
self.v = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
self.attn_drop = nn.Dropout(attn_drop_ratio)
self.ca_gate = nn.Softmax(dim=1) if gate_layer == 'softmax' else nn.Sigmoid()
if window_size == -1:
self.down_func = nn.AdaptiveAvgPool2d((1, 1))
else:
if down_sample_mode == 'recombination':
self.down_func = self.space_to_chans
# dimensionality reduction
self.conv_d = nn.Conv2d(in_channels=dim * window_size ** 2, out_channels=dim, kernel_size=1, bias=False)
elif down_sample_mode == 'avg_pool':
self.down_func = nn.AvgPool2d(kernel_size=(window_size, window_size), stride=window_size)
elif down_sample_mode == 'max_pool':
self.down_func = nn.MaxPool2d(kernel_size=(window_size, window_size), stride=window_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
The dim of x is (B, C, H, W)
"""
# Spatial attention priority calculation
b, c, h_, w_ = x.size()
# (B, C, H)
x_h = x.mean(dim=3)
l_x_h, g_x_h_s, g_x_h_m, g_x_h_l = torch.split(x_h, self.group_chans, dim=1)
# (B, C, W)
x_w = x.mean(dim=2)
l_x_w, g_x_w_s, g_x_w_m, g_x_w_l = torch.split(x_w, self.group_chans, dim=1)
x_h_attn = self.sa_gate(self.norm_h(torch.cat((
self.local_dwc(l_x_h),
self.global_dwc_s(g_x_h_s),
self.global_dwc_m(g_x_h_m),
self.global_dwc_l(g_x_h_l),
), dim=1)))
x_h_attn = x_h_attn.view(b, c, h_, 1)
x_w_attn = self.sa_gate(self.norm_w(torch.cat((
self.local_dwc(l_x_w),
self.global_dwc_s(g_x_w_s),
self.global_dwc_m(g_x_w_m),
self.global_dwc_l(g_x_w_l)
), dim=1)))
x_w_attn = x_w_attn.view(b, c, 1, w_)
x = x * x_h_attn * x_w_attn
# Channel attention based on self attention
# reduce calculations
y = self.down_func(x)
y = self.conv_d(y)
_, _, h_, w_ = y.size()
# normalization first, then reshape -> (B, H, W, C) -> (B, C, H * W) and generate q, k and v
y = self.norm(y)
q = self.q(y)
k = self.k(y)
v = self.v(y)
# (B, C, H, W) -> (B, head_num, head_dim, N)
q = rearrange(q, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
head_dim=int(self.head_dim))
k = rearrange(k, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
head_dim=int(self.head_dim))
v = rearrange(v, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
head_dim=int(self.head_dim))
# (B, head_num, head_dim, head_dim)
attn = q @ k.transpose(-2, -1) * self.scaler
attn = self.attn_drop(attn.softmax(dim=-1))
# (B, head_num, head_dim, N)
attn = attn @ v
# (B, C, H_, W_)
attn = rearrange(attn, 'b head_num head_dim (h w) -> b (head_num head_dim) h w', h=int(h_), w=int(w_))
# (B, C, 1, 1)
attn = attn.mean((2, 3), keepdim=True)
attn = self.ca_gate(attn)
return attn * x
if __name__ == '__main__':
x = torch.randn(4, 512, 7, 7).cuda()
model = SCSA(512, 1).cuda()
out = model(x)
print(out.shape)