(即插即用模块-Attention部分) 四十三、(TIM 2023) LSGA 轻量自高斯注意力

在这里插入图片描述

paper:Light Self-Gaussian-Attention Vision Transformer for Hyperspectral Image Classification

Code:https://github.com/machao132/LSGA-VIT


1、Light self-Gaussian-Attention

在现有的研究中存在以下两个问题,即 Transformer 计算成本高: 传统 Transformer 使用 QKV (Query-Key-Value) 结构进行自注意力计算,需要大量参数和计算量,导致模型复杂度高,运行效率低。位置信息缺失: Transformer 无法直接获取 token 的位置信息,这可能导致模型在处理图像时无法区分中心像素和周围像素的特征,影响分类精度。这篇论文提出一种 轻量自高斯注意力(Light Self-Gaussian-Attention),其是一种轻量级的自注意力机制,用于改进 Transformer 在图像处理任务中的性能。

LSGA 的主要思想是 高斯绝对位置偏差,LSGA 使用高斯函数模拟高光谱图像中中心像素和周围像素的特征关系,并引入高斯绝对位置偏差,使注意力权重更偏向中心像素,从而增强中心像素特征的表达,并抑制周围像素的干扰。

对于输入X,LSGA 的实现过程:

  1. 输入矩阵 X:假设输入矩阵 X 的维度为 (t, c’),其中 t 表示 token 的数量,c’ 表示 token 的维度。
  2. Query 矩阵:将 X 作为 Query 矩阵,即 Q = X。
  3. Key 矩阵和 Value 矩阵:将 X 作为 Key 矩阵和 Value 矩阵,即 K = X, V = X。
  4. 计算注意力权重:计算 Q 和 K 的点积,得到 QKT 矩阵。对 QKT 矩阵进行 Softmax 操作,得到注意力权重矩阵 D。
  5. 计算注意力输出:将注意力权重矩阵 D 与 Value 矩阵 V 相乘,得到注意力输出矩阵 Y = DV。
  6. 线性投影:将注意力输出矩阵 Y 通过线性投影层,得到最终输出矩阵 Z。

Light self-Gaussian-Attention 结构图:
在这里插入图片描述

2、LSGA-ViT

基于 LSGA,这篇论文还提出一种 LSGA-ViT,LSGA-ViT 是一种基于轻量级自高斯注意力 (LSGA) 机制的视 Transformer 模型,用于高光谱图像分类任务。其结合了 CNN 和 Transformer 的优势,有效地提取空间-光谱特征,并通过长距离建模能力进行全局特征提取,从而提高分类精度。LSGA-ViT 主要由以下几个模块构成:

  • 混合空间-光谱 tokenization 模块:使用 3D 卷积层提取空间-光谱特征。以及使用 2D 卷积层进一步提取特征,并调整 token 的维度。
  • LSGA Transformer 块:使用 LSA 进行自注意力计算,并引入高斯绝对位置偏差。然后使用 MLP 层进一步处理 token。
  • 分类器:将最终 token 送入线性层进行分类。

LSGA-ViT 结构图:
在这里插入图片描述

3、代码实现

import torch
import torch.nn as nn
from einops import rearrange, repeat


class LSGAttention(nn.Module):
    def __init__(self, dim, att_inputsize, num_heads=4, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.att_inputsize = att_inputsize[0]
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)
        totalpixel = self.att_inputsize * self.att_inputsize
        gauss_coords_h = torch.arange(totalpixel) - int((totalpixel - 1) / 2)
        gauss_coords_w = torch.arange(totalpixel) - int((totalpixel - 1) / 2)
        gauss_x, gauss_y = torch.meshgrid([gauss_coords_h, gauss_coords_w])
        sigma = 10
        gauss_pos_index = torch.exp(torch.true_divide(-(gauss_x ** 2 + gauss_y ** 2), (2 * sigma ** 2)))
        self.register_buffer("gauss_pos_index", gauss_pos_index)
        self.token_wA = nn.Parameter(torch.empty(1, self.att_inputsize * self.att_inputsize, dim),
                                     requires_grad=True)  # Tokenization parameters
        torch.nn.init.xavier_normal_(self.token_wA)
        self.token_wV = nn.Parameter(torch.empty(1, dim, dim),
                                     requires_grad=True)  # Tokenization parameters
        torch.nn.init.xavier_normal_(self.token_wV)

    def forward(self, x, mask=None):
        B_, N, C = x.shape
        wa = repeat(self.token_wA, '() n d -> b n d', b=B_)  # wa (bs 4 64)
        wa = rearrange(wa, 'b h w -> b w h')  # Transpose # wa (bs 64 4)
        A = torch.einsum('bij,bjk->bik', x, wa)  # A (bs 81 4)
        A = rearrange(A, 'b h w -> b w h')  # Transpose # A (bs 4 81)
        A = A.softmax(dim=-1)
        VV = repeat(self.token_wV, '() n d -> b n d', b=B_)  # VV(bs,64,64)
        VV = torch.einsum('bij,bjk->bik', x, VV)  # VV(bs,81,64)
        x = torch.einsum('bij,bjk->bik', A, VV)  # T(bs,4,64)
        absolute_pos_bias = self.gauss_pos_index.unsqueeze(0)
        q = self.qkv(x).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        k = x.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        v = x.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        attn = attn + absolute_pos_bias.unsqueeze(0)
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


if __name__ == '__main__':
    H, W = 7, 7
    x = torch.randn(4, 512, 7, 7)
    x = rearrange(x, 'b c h w -> b (h w) c')
    model = LSGAttention(512, (7, 7))
    output = model(x)
    output = rearrange(output, 'b (h w) c -> b c h w', h=H, w=W)
    print(output.shape)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

御宇w

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值