YOLO即插即用模块---ASSA

Adapt or Perish: Adaptive Sparse Transformer with Attentive Feature Refinement for Image Restoration

论文地址:

Adapt or Perish: Adaptive Sparse Transformer with Attentive Feature Refinement for Image Restorationicon-default.png?t=O83Ahttps://openaccess.thecvf.com/content/CVPR2024/papers/Zhou_Adapt_or_Perish_Adaptive_Sparse_Transformer_with_Attentive_Feature_Refinement_CVPR_2024_paper.pdf

主要问题: 基于Transformer的图像恢复方法在建模长距离依赖关系方面表现出色,但同时也存在计算量大、冗余信息多和噪声交互等问题。

解决方案: 论文提出了一个自适应稀疏Transformer (AST) 模型,旨在减少无关区域的噪声交互,并消除空间和通道域中的特征冗余。

AST模型主要包含两个核心设计

  • 自适应稀疏自注意力 (ASSA) 模块: 该模块采用双分支模式,包括稀疏自注意力分支 (SSA) 和密集自注意力分支 (DSA)。SSA用于过滤掉低查询-键匹配分数的负面影響,而DSA则确保足够的信息流通过网络网络,以学习判别性表示。

  • 特征细化前馈网络 (FRFN): 该模块采用增强和简化方案来 适用任务: 论 AST模型在多个图像恢复任务中表现出色,包括:

  • 雨痕去除: 在SPAD数据集上,AST-B模型在PSNR指标上比现有的最佳的CNN模型和Transformer模型。

  • 雨滴去除: 在AGAN-Data数据集上,AST-B模型在PSNR指标上优于之前最佳的雨滴去除方法和。

  • 真实雾去除: 在Dense-Haze数据集上,AST-B模型在PSNR指标上优于之前最佳的雾去除方法。

 AST模型通过自适应稀疏自注意力和特征细化前馈网络,有效地解决了基于  AST模型通过自适应稀疏自注意力和特征细化前馈网络,有效地解决了基于  AST模型通过自适应稀疏自注意力和特征细化前馈网络,有效地解决了基于Transformer的图像恢复方法中存在的计算量大、冗余信息多和噪声交互等问题,并在多个图像恢复任务中取得了优异的性能。

即插即用代码:

import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_
from einops import repeat

class LinearProjection(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, bias=True):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.to_q = nn.Linear(dim, inner_dim, bias = bias)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = bias)
        self.dim = dim
        self.inner_dim = inner_dim

    def forward(self, x, attn_kv=None):
        B_, N, C = x.shape
        if attn_kv is not None:
            attn_kv = attn_kv.unsqueeze(0).repeat(B_,1,1)
        else:
            attn_kv = x
        N_kv = attn_kv.size(1)
        q = self.to_q(x).reshape(B_, N, 1, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
        kv = self.to_kv(attn_kv).reshape(B_, N_kv, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
        q = q[0]
        k, v = kv[0], kv[1]
        return q,k,v

# Adaptive Sparse Self-Attention (ASSA)
class WindowAttention_sparse(nn.Module):
    def __init__(self, dim, win_size, num_heads=8, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0.,
                 proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.win_size = win_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.win_size[0])  # [0,...,Wh-1]
        coords_w = torch.arange(self.win_size[1])  # [0,...,Ww-1]
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.win_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.win_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)
        trunc_normal_(self.relative_position_bias_table, std=.02)

        if token_projection == 'linear':
            self.qkv = LinearProjection(dim, num_heads, dim // num_heads, bias=qkv_bias)
        else:
            raise Exception("Projection error!")

        self.token_projection = token_projection
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)
        self.relu = nn.ReLU()
        self.w = nn.Parameter(torch.ones(2))

    def forward(self, x, attn_kv=None, mask=None):
        B_, N, C = x.shape
        q, k, v = self.qkv(x, attn_kv)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        ratio = attn.size(-1) // relative_position_bias.size(-1)
        relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d=ratio)

        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            mask = repeat(mask, 'nW m n -> nW m (n d)', d=ratio)
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N * ratio) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N * ratio)
            attn0 = self.softmax(attn)
            attn1 = self.relu(attn) ** 2  # b,h,w,c
        else:
            attn0 = self.softmax(attn)
            attn1 = self.relu(attn) ** 2
        w1 = torch.exp(self.w[0]) / torch.sum(torch.exp(self.w))
        w2 = torch.exp(self.w[1]) / torch.sum(torch.exp(self.w))
        attn = attn0 * w1 + attn1 * w2
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


if __name__ == '__main__':
    # Instantiate the WindowAttention_sparse class
    dim = 32  # Dimension of input features
    win_size = (32, 32)  # Window size(H, W)
    # Create an instance of the WindowAttention_sparse module
    window_attention_sparse = WindowAttention_sparse(dim, win_size)
    C = dim
    input = torch.randn(1, 32 * 32, C)#输入B H W
    # Forward pass
    output = window_attention_sparse(input)

    # Print input and output size
    print(input.size())
    print(output.size())

玩yolo的同行可进交流群,群里有答疑(QQ:828370883):

### 自适应稀疏自注意力机制 (ASSA) 的介绍 自适应稀疏自注意力机制(Adaptive Sparse Self-Attention, ASSA)是一种创新性的注意力机制,旨在优化传统自注意力机制中存在的冗余信息问题。通过引入稀疏分支和密集分支的设计,ASSA 能够有效减少不相关区域的噪声相互作用,并消除空间域和通道域中的特征冗余[^2]。 ### 原理 ASSA 使用双分支范式来自适应地计算注意力权重: - **稀疏分支**:该分支专注于识别并抑制那些具有较低查询键匹配分数的特征,从而防止这些低质量特征对最终聚合结果造成负面影响。 - **密集分支**:此部分确保足够的信息流在网络中传递,使得模型能够学习到更具区分度的表示形式。这两个分支共同工作,在不同场景下动态调整其贡献比例,以实现最佳性能。 此外,为了进一步提升效果,ASSA 还结合了一个称为功能精炼前馈网络(Feature Refinement Feedforward Network, FRFN)组件。FRFN 通过对通道内的特征进行增强和简化操作来降低特征密度,进而改善整体表现[^4]。 ### 实现 以下是基于 PyTorch 框架的一个简单示例代码片段展示如何构建一个基本版本的 AST 结构,其中包括 ASSA 和 FRFN 组件: ```python import torch.nn as nn import torch class ASSABlock(nn.Module): def __init__(self, d_model=512, nhead=8): super().__init__() self.sparse_attn = nn.MultiheadAttention(d_model, nhead) self.dense_attn = nn.MultiheadAttention(d_model, nhead) def forward(self, query, key, value): sparse_output, _ = self.sparse_attn(query, key, value) dense_output, _ = self.dense_attn(query, key, value) output = torch.cat((sparse_output, dense_output), dim=-1) return output class FRFNNetwork(nn.Module): def __init__(self, input_dim, hidden_dim, output_dim): super().__init__() self.fc1 = nn.Linear(input_dim, hidden_dim) self.relu = nn.ReLU() self.fc2 = nn.Linear(hidden_dim, output_dim) def forward(self, x): out = self.fc1(x) out = self.relu(out) out = self.fc2(out) return out class AdaptiveSparseTransformer(nn.Module): def __init__(self, d_model=512, nhead=8, num_layers=6): super().__init__() encoder_layer = nn.TransformerEncoderLayer(d_model=d_model * 2, nhead=nhead*2) transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) self.assa_block = ASSABlock(d_model, nhead) self.transformer = transformer_encoder self.frfn_network = FRFNNetwork(d_model * 2, d_model, d_model) def forward(self, src): assa_out = self.assa_block(src, src, src) frfn_out = self.frfn_network(assa_out) final_out = self.transformer(frfn_out) return final_out ``` 这段代码定义了三个主要模块——`ASSABlock`, `FRFNNetwork` 及整个架构的核心类 `AdaptiveSparseTransformer`. 它们协同工作实现了 ASSA 所描述的功能特性. ### 应用 ASSA 已经被成功应用于多种领域内的重要任务当中: - 在目标检测方面,YOLOv11 中加入了 ASSA 改进了原有模型对于复杂背景下的物体定位精度[^1]. - 对于时间序列预测而言,LSTM 加上 Transformer 并融入 ASSA 后显著提高了对未来趋势变化预估的能力. - 图像修复任务也受益匪浅; Adapt or Perish 论文中提到的方法不仅减少了计算资源消耗还提升了去噪、除雾等多个子任务的表现.
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值