paper:Light Self-Gaussian-Attention Vision Transformer for Hyperspectral Image Classification
1、Light self-Gaussian-Attention
在现有的研究中存在以下两个问题,即 Transformer 计算成本高: 传统 Transformer 使用 QKV (Query-Key-Value) 结构进行自注意力计算,需要大量参数和计算量,导致模型复杂度高,运行效率低。位置信息缺失: Transformer 无法直接获取 token 的位置信息,这可能导致模型在处理图像时无法区分中心像素和周围像素的特征,影响分类精度。这篇论文提出一种 轻量自高斯注意力(Light Self-Gaussian-Attention),其是一种轻量级的自注意力机制,用于改进 Transformer 在图像处理任务中的性能。
LSGA 的主要思想是 高斯绝对位置偏差,LSGA 使用高斯函数模拟高光谱图像中中心像素和周围像素的特征关系,并引入高斯绝对位置偏差,使注意力权重更偏向中心像素,从而增强中心像素特征的表达,并抑制周围像素的干扰。
对于输入X,LSGA 的实现过程:
- 输入矩阵 X:假设输入矩阵 X 的维度为 (t, c’),其中 t 表示 token 的数量,c’ 表示 token 的维度。
- Query 矩阵:将 X 作为 Query 矩阵,即 Q = X。
- Key 矩阵和 Value 矩阵:将 X 作为 Key 矩阵和 Value 矩阵,即 K = X, V = X。
- 计算注意力权重:计算 Q 和 K 的点积,得到 QKT 矩阵。对 QKT 矩阵进行 Softmax 操作,得到注意力权重矩阵 D。
- 计算注意力输出:将注意力权重矩阵 D 与 Value 矩阵 V 相乘,得到注意力输出矩阵 Y = DV。
- 线性投影:将注意力输出矩阵 Y 通过线性投影层,得到最终输出矩阵 Z。
Light self-Gaussian-Attention 结构图:
2、LSGA-ViT
基于 LSGA,这篇论文还提出一种 LSGA-ViT,LSGA-ViT 是一种基于轻量级自高斯注意力 (LSGA) 机制的视 Transformer 模型,用于高光谱图像分类任务。其结合了 CNN 和 Transformer 的优势,有效地提取空间-光谱特征,并通过长距离建模能力进行全局特征提取,从而提高分类精度。LSGA-ViT 主要由以下几个模块构成:
- 混合空间-光谱 tokenization 模块:使用 3D 卷积层提取空间-光谱特征。以及使用 2D 卷积层进一步提取特征,并调整 token 的维度。
- LSGA Transformer 块:使用 LSA 进行自注意力计算,并引入高斯绝对位置偏差。然后使用 MLP 层进一步处理 token。
- 分类器:将最终 token 送入线性层进行分类。
LSGA-ViT 结构图:
3、代码实现
import torch
import torch.nn as nn
from einops import rearrange, repeat
class LSGAttention(nn.Module):
def __init__(self, dim, att_inputsize, num_heads=4, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.att_inputsize = att_inputsize[0]
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
totalpixel = self.att_inputsize * self.att_inputsize
gauss_coords_h = torch.arange(totalpixel) - int((totalpixel - 1) / 2)
gauss_coords_w = torch.arange(totalpixel) - int((totalpixel - 1) / 2)
gauss_x, gauss_y = torch.meshgrid([gauss_coords_h, gauss_coords_w])
sigma = 10
gauss_pos_index = torch.exp(torch.true_divide(-(gauss_x ** 2 + gauss_y ** 2), (2 * sigma ** 2)))
self.register_buffer("gauss_pos_index", gauss_pos_index)
self.token_wA = nn.Parameter(torch.empty(1, self.att_inputsize * self.att_inputsize, dim),
requires_grad=True) # Tokenization parameters
torch.nn.init.xavier_normal_(self.token_wA)
self.token_wV = nn.Parameter(torch.empty(1, dim, dim),
requires_grad=True) # Tokenization parameters
torch.nn.init.xavier_normal_(self.token_wV)
def forward(self, x, mask=None):
B_, N, C = x.shape
wa = repeat(self.token_wA, '() n d -> b n d', b=B_) # wa (bs 4 64)
wa = rearrange(wa, 'b h w -> b w h') # Transpose # wa (bs 64 4)
A = torch.einsum('bij,bjk->bik', x, wa) # A (bs 81 4)
A = rearrange(A, 'b h w -> b w h') # Transpose # A (bs 4 81)
A = A.softmax(dim=-1)
VV = repeat(self.token_wV, '() n d -> b n d', b=B_) # VV(bs,64,64)
VV = torch.einsum('bij,bjk->bik', x, VV) # VV(bs,81,64)
x = torch.einsum('bij,bjk->bik', A, VV) # T(bs,4,64)
absolute_pos_bias = self.gauss_pos_index.unsqueeze(0)
q = self.qkv(x).reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
k = x.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
v = x.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
attn = attn + absolute_pos_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
if __name__ == '__main__':
H, W = 7, 7
x = torch.randn(4, 512, 7, 7)
x = rearrange(x, 'b c h w -> b (h w) c')
model = LSGAttention(512, (7, 7))
output = model(x)
output = rearrange(output, 'b (h w) c -> b c h w', h=H, w=W)
print(output.shape)