(即插即用模块-Attention部分) 六十、(2024) SCSA 空间通道协同注意力

在这里插入图片描述

paper:SCSA: Exploring the Synergistic Effects Between Spatial and Channel Attention

Code:https://github.com/HZAI-ZJNU/SCSA


1、Spatial and Channel Synergistic Attention

通道和空间注意力分别为各种下游视觉任务的特征依赖性和空间结构关系提取带来了显著的改进。虽然两者的组合更有利于发挥各自的优势,但通道和空间注意力之间的协同作用尚未得到充分探索,还缺乏充分利用多语义信息的协同潜力,所以,论文提出一种新的 空间通道协同注意力模块(SCSA)。SCSA由两部分组成:**Shared Multi-Semantic Spatial Attention(SMSA )**和 Progressive Channel-wise Self-Attention(PCSA )。其中,SMSA 用于集成多语义信息,并利用渐进压缩策略将区分性空间先验注入到PCSA的通道自注意力中,有效将通道重新校准。而 PCSA 中基于自注意力机制的鲁棒特征交互则进一步缓解了SMSA中不同子特征之间的多语义信息差异。具体来说:

  • SMSA: 采用多尺度 depth-wise 1D convolutions,分别从四个独立的子特征中提取不同语义层次的空间信息。并利用GroupNom来加速模型收敛,同时避免引入批量噪声以及子特征之间的语义信息泄漏的问题。
  • **PCSA:**通过结合渐进压缩和通道特定的自注意力机制(CSA),最大限度地减少计算复杂性,同时保留 SMSA 内的空间先验。此外,PCSA 利用自我注意力机制,进一步探索了通道层面的相似性,从而减少了不同子特征之间的语义差异。

SCSA 旨在相互补充。空间注意力从每个特征中提取多语义空间信息,为通道注意力计算提供精确的空间先验;而通道注意力通过利用整体特征图来细化局部子特征的语义理解,减轻SMSA中多尺度卷积引起的语义差异。

对于一个给定的输入X,SCCA 的实现过程:

SMSA:

  1. 首先沿着H,W维度进行分解,并分别进行全局平均池化,从而建立两个单向的一维序列结构。

  2. 然后为了学习不同的空间分布和上下文关系,将特征集划分为K个相同大小的独立子特征。并经由多尺度 depth-wise 1D convolutions 处理。通过多尺度机制来更高效地捕获每个子特征内的不同语义空间结构。

  3. 为了解决 特征分解 与 一维卷积 导致的有限感受域,在depth-wise 1D convolutions之后,使用了轻量级共享卷积进行特征对齐。

  4. 最后,聚合不同的语义子特征并使用 GroupNorm 进行归一化,然后使用 Sigmoid 激活函数并与X相乘生成空间注意力。


PCSA:

  1. 由 SMSA 处理后的特征 X_s 先经由平均池化,然后进行 GroupNorm 归一化处理。

  2. 通过使用多分支的DWConv生成 Q,K,V。并通过自注意力聚合。

  3. 最后通过一个平均池化层和sigmoid激活函数,并与 X_s 进行相乘生成注意力图。


SCSA 结构图:
在这里插入图片描述


2、代码实现

import typing as t

import torch
import torch.nn as nn
from einops import rearrange
from mmengine.model import BaseModule
__all__ = ['SCSA']


class SCSA(BaseModule):

    def __init__(
            self,
            dim: int,
            head_num: int,
            window_size: int = 7,
            group_kernel_sizes: t.List[int] = [3, 5, 7, 9],
            qkv_bias: bool = False,
            fuse_bn: bool = False,
            norm_cfg: t.Dict = dict(type='BN'),
            act_cfg: t.Dict = dict(type='ReLU'),
            down_sample_mode: str = 'avg_pool',
            attn_drop_ratio: float = 0.,
            gate_layer: str = 'sigmoid',
    ):
        super(SCSA, self).__init__()
        self.dim = dim
        self.head_num = head_num
        self.head_dim = dim // head_num
        self.scaler = self.head_dim ** -0.5
        self.group_kernel_sizes = group_kernel_sizes
        self.window_size = window_size
        self.qkv_bias = qkv_bias
        self.fuse_bn = fuse_bn
        self.down_sample_mode = down_sample_mode

        assert self.dim // 4, 'The dimension of input feature should be divisible by 4.'
        self.group_chans = group_chans = self.dim // 4

        self.local_dwc = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[0],
                                   padding=group_kernel_sizes[0] // 2, groups=group_chans)
        self.global_dwc_s = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[1],
                                      padding=group_kernel_sizes[1] // 2, groups=group_chans)
        self.global_dwc_m = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[2],
                                      padding=group_kernel_sizes[2] // 2, groups=group_chans)
        self.global_dwc_l = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[3],
                                      padding=group_kernel_sizes[3] // 2, groups=group_chans)
        self.sa_gate = nn.Softmax(dim=2) if gate_layer == 'softmax' else nn.Sigmoid()
        self.norm_h = nn.GroupNorm(4, dim)
        self.norm_w = nn.GroupNorm(4, dim)

        self.conv_d = nn.Identity()
        self.norm = nn.GroupNorm(1, dim)
        self.q = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
        self.k = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
        self.v = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.ca_gate = nn.Softmax(dim=1) if gate_layer == 'softmax' else nn.Sigmoid()

        if window_size == -1:
            self.down_func = nn.AdaptiveAvgPool2d((1, 1))
        else:
            if down_sample_mode == 'recombination':
                self.down_func = self.space_to_chans
                # dimensionality reduction
                self.conv_d = nn.Conv2d(in_channels=dim * window_size ** 2, out_channels=dim, kernel_size=1, bias=False)
            elif down_sample_mode == 'avg_pool':
                self.down_func = nn.AvgPool2d(kernel_size=(window_size, window_size), stride=window_size)
            elif down_sample_mode == 'max_pool':
                self.down_func = nn.MaxPool2d(kernel_size=(window_size, window_size), stride=window_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        The dim of x is (B, C, H, W)
        """
        # Spatial attention priority calculation
        b, c, h_, w_ = x.size()
        # (B, C, H)
        x_h = x.mean(dim=3)
        l_x_h, g_x_h_s, g_x_h_m, g_x_h_l = torch.split(x_h, self.group_chans, dim=1)
        # (B, C, W)
        x_w = x.mean(dim=2)
        l_x_w, g_x_w_s, g_x_w_m, g_x_w_l = torch.split(x_w, self.group_chans, dim=1)

        x_h_attn = self.sa_gate(self.norm_h(torch.cat((
            self.local_dwc(l_x_h),
            self.global_dwc_s(g_x_h_s),
            self.global_dwc_m(g_x_h_m),
            self.global_dwc_l(g_x_h_l),
        ), dim=1)))
        x_h_attn = x_h_attn.view(b, c, h_, 1)

        x_w_attn = self.sa_gate(self.norm_w(torch.cat((
            self.local_dwc(l_x_w),
            self.global_dwc_s(g_x_w_s),
            self.global_dwc_m(g_x_w_m),
            self.global_dwc_l(g_x_w_l)
        ), dim=1)))
        x_w_attn = x_w_attn.view(b, c, 1, w_)

        x = x * x_h_attn * x_w_attn

        # Channel attention based on self attention
        # reduce calculations
        y = self.down_func(x)
        y = self.conv_d(y)
        _, _, h_, w_ = y.size()

        # normalization first, then reshape -> (B, H, W, C) -> (B, C, H * W) and generate q, k and v
        y = self.norm(y)
        q = self.q(y)
        k = self.k(y)
        v = self.v(y)
        # (B, C, H, W) -> (B, head_num, head_dim, N)
        q = rearrange(q, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
                      head_dim=int(self.head_dim))
        k = rearrange(k, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
                      head_dim=int(self.head_dim))
        v = rearrange(v, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
                      head_dim=int(self.head_dim))

        # (B, head_num, head_dim, head_dim)
        attn = q @ k.transpose(-2, -1) * self.scaler
        attn = self.attn_drop(attn.softmax(dim=-1))
        # (B, head_num, head_dim, N)
        attn = attn @ v
        # (B, C, H_, W_)
        attn = rearrange(attn, 'b head_num head_dim (h w) -> b (head_num head_dim) h w', h=int(h_), w=int(w_))
        # (B, C, 1, 1)
        attn = attn.mean((2, 3), keepdim=True)
        attn = self.ca_gate(attn)
        return attn * x


if __name__ == '__main__':
    x = torch.randn(4, 512, 7, 7).cuda()
    model = SCSA(512, 1).cuda()
    out = model(x)
    print(out.shape)
### ResNet结合SCSA注意力机制的实现方式 #### SCSA模块概述 SCSA(Spatial and Channel Synergistic Attention)旨在探索空间通道注意力之间的协同效应。该模块通过融合空间维度上的特征图以及不同通道间的依赖关系来增强模型的表现力[^1]。 #### 实现细节 为了将SCSA集成到ResNet架构中,可以在每个残差单元内部引入此注意层。具体来说: - **输入处理**:接收来自前一层的标准卷积输出; - **并行分支设计**: - 创建两个独立路径分别用于计算空间级联响应与跨频道交互模式; - 对于空间部分采用轻量化操作如深度可分离卷积减少参数量;对于渠道则利用全局平均池化获取整体统计特性; - **多尺度建模**:考虑到图像内在结构可能存在多种比例尺变化情况,在上述基础上进一步加入金字塔形采样策略以捕捉更广泛范围内的上下文信息; - **最终融合**:经过各自变换后的两路信号再次汇聚并通过逐元素相乘完成强化表达式的构建过程。 ```python import torch.nn as nn class SpatialAttention(nn.Module): def __init__(self, kernel_size=7): super(SpatialAttention, self).__init__() assert kernel_size in (3, 7), 'kernel size must be 3 or 7' padding = 3 if kernel_size == 7 else 1 self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) self.sigmoid = nn.Sigmoid() def forward(self, x): avg_out = torch.mean(x, dim=1, keepdim=True) max_out, _ = torch.max(x, dim=1, keepdim=True) scale = torch.cat([avg_out, max_out], dim=1) scale = self.conv1(scale) return x * self.sigmoid(scale) class ChannelAttention(nn.Module): def __init__(self, in_planes, ratio=16): super(ChannelAttention, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False) self.relu1 = nn.ReLU() self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False) self.sigmoid = nn.Sigmoid() def forward(self, x): avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) out = self.sigmoid(avg_out) return x * out class SCAModule(nn.Module): expansion = 4 def __init__(self, planes): super(SCAModule, self).__init__() self.spatial_attention = SpatialAttention() self.channel_attention = ChannelAttention(planes*self.expansion) def forward(self, x): residual = x spa_atten = self.spatial_attention(residual) cha_atten = self.channel_attention(spa_atten) out = spa_atten + cha_atten return out ``` #### 应用案例分析 在ImageNet-1K数据集上进行了广泛的实验验证,结果显示当把SCSA嵌入至预训练好的ResNet-50网络之后形成的新变体——命名为`SCSA-50`,不仅显著提升了分类准确性而且有效降低了过拟合风险。 此外,在目标检测领域也有成功实践。例如YOLOv8改进版本中也采用了类似的思路,通过对原有框架添加SCSA组件实现了性能优化[^2]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

御宇w

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值