(即插即用模块-Attention部分) 五十四、(ICCV 2023) ESSA 高效SCC核自注意力

在这里插入图片描述

paper:ESSAformer: Efficient Transformer for Hyperspectral Image Super-resolution

Code:https://github.com/rexzhan/essaformer


1、Efficient SCC-kernel-based Self-Attention

在为高光谱图像超分辨率任务设计的自注意力机制研究中,通常包含以下几点不足:高光谱图像的特性: 高光谱图像包含丰富的光谱和空间信息,而传统自注意力机制无法有效利用这些信息。计算复杂度问题: 传统自注意力机制的计算复杂度为 O(N^2),其中 N 为序列长度,这在高分辨率高光谱图像中会导致巨大的计算负担。数据效率问题: 高光谱图像数据获取困难,难以构建大规模训练数据集,导致模型训练效率低下。为了有效缓解这些问题,这篇论文提出一种 高效SCC核自注意力(Spectral Correlation Coefficient of Spectrum-kernel-based Self-Attention),其旨在解决传统自注意力机制在高光谱图像中存在的计算复杂度高、数据效率低的问题。

ESSA 通过利用光谱相关系数 (SCC) 作为相似性度量,并结合核技巧来实现整个过程。其中,SCC 是一种鲁棒的光谱相似性度量,它能够有效地衡量两个光谱曲线之间的相关性,并具有平移不变性和缩放不变性,使其对阴影和遮挡等干扰因素不敏感。此外, ESSA 将 SCC 与高斯径向基函数 (RBF) 核相结合,利用 Mercer 定理证明了 SCC 核的存在性,并通过泰勒展开将其转换为线性计算复杂度的形式。

对于输入X,ESSA 的实现过程:

  1. 特征嵌入: 将输入特征图投影为三个序列,即查询 (Q)、键 (K) 和值 (V)。
  2. SCC 计算: 计算查询序列和键序列之间的 SCC 值,并将其平方后作为注意力矩阵。
  3. 核函数应用: 将 SCC 核应用于查询序列和键序列,得到新的序列。
  4. 注意力计算: 计算查询序列与键序列之间的点积,并应用 softmax 函数作为权重,对值序列进行加权求和,得到最终的输出序列。

Efficient SCC-kernel-based Self-Attention 结构图:
在这里插入图片描述


2、ESSAformer

在ESSA的基础上,论文还提出一种基于 Transformer 的网络架构 ESSAformer ,专门用于高光谱图像超分辨率 (HSI-SR) 任务。其结合了 ESSA 自注意力机制和迭代细化结构,能够有效地生成高分辨率高光谱图像。

ESSAformer 的基本结构如下:

  • 特征提取: 使用投影层将原始高光谱图像转换为特征图。
  • 迭代细化: 包含多个阶段,每个阶段包括:重缩放模块: 对特征图进行上采样或下采样,以捕获不同尺度的特征。编码器层: 包含 ESSA 自注意力模块和前馈网络 (FFN) 模块,用于编码特征信息。残差连接: 将相邻三个阶段的特征图进行残差连接,以融合多尺度特征信息。
  • 输出: 使用卷积层将特征图投影回原始通道维度,得到高分辨率高光谱图像。

ESSAformer 结构图:
在这里插入图片描述

3、代码实现

import torch
import torch.nn as nn
import math
from einops.einops import rearrange


class ESSAttn(nn.Module):

    def __init__(self, dim):
        super().__init__()
        self.lnqkv = nn.Linear(dim, dim * 3)
        self.ln = nn.Linear(dim, dim)

    def forward(self, x):
        b, N, C = x.shape
        qkv = self.lnqkv(x)
        qkv = torch.split(qkv, C, 2)
        q, k, v = qkv[0], qkv[1], qkv[2]
        a = torch.mean(q, dim=2, keepdim=True)
        q = q - a
        a = torch.mean(k, dim=2, keepdim=True)
        k = k - a
        q2 = torch.pow(q, 2)
        q2s = torch.sum(q2, dim=2, keepdim=True)
        k2 = torch.pow(k, 2)
        k2s = torch.sum(k2, dim=2, keepdim=True)
        t1 = v
        k2 = torch.nn.functional.normalize((k2 / (k2s + 1e-7)), dim=-2)
        q2 = torch.nn.functional.normalize((q2 / (q2s + 1e-7)), dim=-1)
        t2 = q2 @ (k2.transpose(-2, -1) @ v) / math.sqrt(N)
        attn = t1 + t2
        attn = self.ln(attn)
        return attn

    def is_same_matrix(self, m1, m2):
        rows, cols = len(m1), len(m1[0])
        for i in range(rows):
            for j in range(cols):
                if m1[i][j] != m2[i][j]:
                    return False
        return True


if __name__ == '__main__':
    H, W = 7, 7
    x = torch.randn(4, 512, 7, 7)
    x = rearrange(x, 'b c h w -> b (h w) c')
    model = ESSAttn(512)
    output = model(x)
    output = rearrange(output, 'b (h w) c -> b c h w', h=H, w=W)
    print(output.shape)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

御宇w

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值