paper:ESSAformer: Efficient Transformer for Hyperspectral Image Super-resolution
1、Efficient SCC-kernel-based Self-Attention
在为高光谱图像超分辨率任务设计的自注意力机制研究中,通常包含以下几点不足:高光谱图像的特性: 高光谱图像包含丰富的光谱和空间信息,而传统自注意力机制无法有效利用这些信息。计算复杂度问题: 传统自注意力机制的计算复杂度为 O(N^2),其中 N 为序列长度,这在高分辨率高光谱图像中会导致巨大的计算负担。数据效率问题: 高光谱图像数据获取困难,难以构建大规模训练数据集,导致模型训练效率低下。为了有效缓解这些问题,这篇论文提出一种 高效SCC核自注意力(Spectral Correlation Coefficient of Spectrum-kernel-based Self-Attention),其旨在解决传统自注意力机制在高光谱图像中存在的计算复杂度高、数据效率低的问题。
ESSA 通过利用光谱相关系数 (SCC) 作为相似性度量,并结合核技巧来实现整个过程。其中,SCC 是一种鲁棒的光谱相似性度量,它能够有效地衡量两个光谱曲线之间的相关性,并具有平移不变性和缩放不变性,使其对阴影和遮挡等干扰因素不敏感。此外, ESSA 将 SCC 与高斯径向基函数 (RBF) 核相结合,利用 Mercer 定理证明了 SCC 核的存在性,并通过泰勒展开将其转换为线性计算复杂度的形式。
对于输入X,ESSA 的实现过程:
- 特征嵌入: 将输入特征图投影为三个序列,即查询 (Q)、键 (K) 和值 (V)。
- SCC 计算: 计算查询序列和键序列之间的 SCC 值,并将其平方后作为注意力矩阵。
- 核函数应用: 将 SCC 核应用于查询序列和键序列,得到新的序列。
- 注意力计算: 计算查询序列与键序列之间的点积,并应用 softmax 函数作为权重,对值序列进行加权求和,得到最终的输出序列。
Efficient SCC-kernel-based Self-Attention 结构图:
2、ESSAformer
在ESSA的基础上,论文还提出一种基于 Transformer 的网络架构 ESSAformer ,专门用于高光谱图像超分辨率 (HSI-SR) 任务。其结合了 ESSA 自注意力机制和迭代细化结构,能够有效地生成高分辨率高光谱图像。
ESSAformer 的基本结构如下:
- 特征提取: 使用投影层将原始高光谱图像转换为特征图。
- 迭代细化: 包含多个阶段,每个阶段包括:重缩放模块: 对特征图进行上采样或下采样,以捕获不同尺度的特征。编码器层: 包含 ESSA 自注意力模块和前馈网络 (FFN) 模块,用于编码特征信息。残差连接: 将相邻三个阶段的特征图进行残差连接,以融合多尺度特征信息。
- 输出: 使用卷积层将特征图投影回原始通道维度,得到高分辨率高光谱图像。
ESSAformer 结构图:
3、代码实现
import torch
import torch.nn as nn
import math
from einops.einops import rearrange
class ESSAttn(nn.Module):
def __init__(self, dim):
super().__init__()
self.lnqkv = nn.Linear(dim, dim * 3)
self.ln = nn.Linear(dim, dim)
def forward(self, x):
b, N, C = x.shape
qkv = self.lnqkv(x)
qkv = torch.split(qkv, C, 2)
q, k, v = qkv[0], qkv[1], qkv[2]
a = torch.mean(q, dim=2, keepdim=True)
q = q - a
a = torch.mean(k, dim=2, keepdim=True)
k = k - a
q2 = torch.pow(q, 2)
q2s = torch.sum(q2, dim=2, keepdim=True)
k2 = torch.pow(k, 2)
k2s = torch.sum(k2, dim=2, keepdim=True)
t1 = v
k2 = torch.nn.functional.normalize((k2 / (k2s + 1e-7)), dim=-2)
q2 = torch.nn.functional.normalize((q2 / (q2s + 1e-7)), dim=-1)
t2 = q2 @ (k2.transpose(-2, -1) @ v) / math.sqrt(N)
attn = t1 + t2
attn = self.ln(attn)
return attn
def is_same_matrix(self, m1, m2):
rows, cols = len(m1), len(m1[0])
for i in range(rows):
for j in range(cols):
if m1[i][j] != m2[i][j]:
return False
return True
if __name__ == '__main__':
H, W = 7, 7
x = torch.randn(4, 512, 7, 7)
x = rearrange(x, 'b c h w -> b (h w) c')
model = ESSAttn(512)
output = model(x)
output = rearrange(output, 'b (h w) c -> b c h w', h=H, w=W)
print(output.shape)