一、MST框架
1.初始化信号处理:
首先,从测量数据中生成初始化的高光谱图像信号 (H)。
2.编码器阶段:
将H输入模型。首先,MST利用conv3×3层将H映射为特征X0。接着, X0经过N1个MSAB块、一个下采样模块 、N2个MSAB块和一个下采样模块 生成 分层特征。(下采样模块是一个步长为4×4的卷积层,用于降低特征图的尺寸 并增加通道数量) 逐步提取并减少特征的空间维度。
3.瓶颈:
X2 通过一个包含N3个MSAB块的瓶颈(这里主要对特征进行压缩和提取最 重要的特征信息)
4.解码器阶段:
按照U-Net的精神设计了对称的解码器结构。 使用上采样操作将瓶颈层 的特征进行逐步恢复和扩展。利用跳跃连接(skipconnections)从编码器的相应 阶段获取信息 上采样模块是一个步长为2×2的反卷积层。跳跃连接用于在编码器和解码器 之间进行特征聚合,减轻由下采样操作引起的信息丢失
5.重建HSIs:
特征图通过卷积层(一层conv3×3层)将解码器的输出转换为残差HSIs(R)。
将这些残差添加到初始信号 (H) 中,得到重建的高光谱图像 (H'=H+R)。
6.模型变体和优化:
MST 根据不同的设计参数和模型变体(如MST-S、MST-M、MST-L)进行优 化和适应不同的计算需求和复杂度。
MSAB: MST 的基本单元是MSAB,包括两次层归一化 、一个掩模引导的频谱注意力 模块(MS-MSA)和一个前馈神经网络(FFN)
二、核心创新
1.Spectral-wise Multi-head Self-Attention (S-MSA)
在HSI 领域,传统的基于CNN的方法往往难以很好地捕捉 数据中存在的非 局部自相似性信息。为了解决这一挑战,本文提出了利用Transformer架构和自 注意力机制来更好地建模高光谱数据中的非局部依赖关系。
-
设计理念:S-MSA将每个光谱特征视为一个token,并专注于沿光谱维度计算自注意力。这种设计旨在捕捉高光谱图像(HSI)中不同光谱通道之间的相似性和依赖性。
-
操作流程:
- 输入重塑:输入特征被重塑为token形式,使得每个光谱通道的特征能够被视为一个独立的token。
- 线性投影:通过线性变换生成查询(Q)、键(K)和值(V),这些都是用于计算自注意力的基本组件。
- 头的分割:将Q、K和V分成多个头,以便在光谱维度上并行计算自注意力。
- 自注意力计算:对于每个头,计算自注意力,通过对Q和K的点积进行softmax操作,得到注意力权重,然后将这些权重应用于V。
- 输出拼接:将所有头的输出拼接在一起,并通过线性变换生成最终的输出特征。
# Multi-Scale Multi-Head Self Attention模块 问题一 MS_MSA
class MS_MSA(nn.Module):
def __init__(
self,
dim, # 输入向量的维度
dim_head=64, # 每个注意力头的维度
heads=8, # 注意力头的数量
):
super().__init__()
self.num_heads = heads # 注意力头的数量
self.dim_head = dim_head # 每个头的维度
# Q K V 是线性层,用于将输入向量 x 映射到查询(Q)、键(K)、值(V)的空间
self.to_q = nn.Linear(dim, dim_head * heads, bias=False)
self.to_k = nn.Linear(dim, dim_head * heads, bias=False)
self.to_v = nn.Linear(dim, dim_head * heads, bias=False)
# 是一个可学习的参数,用于缩放注意力权重
self.rescale = nn.Parameter(torch.ones(heads, 1, 1))
# 线性层,用于将经过注意力计算后的输出映射回原始维度
self.proj = nn.Linear(dim_head * heads, dim, bias=True)
# 包含两个卷积层和GELU激活函数的序列,用于生成位置编码
self.pos_emb = nn.Sequential(
nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
GELU(),
nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
)
self.mm = MaskGuidedMechanism(dim)
self.dim = dim
def forward(self, x_in, mask=None):
"""
x_in: [b,h,w,c]
mask: [1,h,w,c]
return out: [b,h,w,c]
"""
b, h, w, c = x_in.shape
x = x_in.reshape(b,h*w,c)
# 将输入向量映射为查询(Q)、键(K)、值(V)的结果
q_inp = self.to_q(x)
k_inp = self.to_k(x)
v_inp = self.to_v(x)
# 处理掩码信息,如果有的话
mask_attn = self.mm(mask.permute(0,3,1,2)).permute(0,2,3,1)
if b != 0:
mask_attn = (mask_attn[0, :, :, :]).expand([b, h, w, c])
# 对 Q、K、V 进行重组以适应多头注意力计算
q, k, v, mask_attn = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads),
(q_inp, k_inp, v_inp, mask_attn.flatten(1, 2)))
# 应用掩码到值(V)上
v = v * mask_attn
# q: b,heads,hw,c
# 调整 Q 的维度顺序
q = q.transpose(-2, -1)
k = k.transpose(-2, -1)
v = v.transpose(-2, -1)
# 归一化 Q 和 K
q = F.normalize(q, dim=-1, p=2)
k = F.normalize(k, dim=-1, p=2)
# 计算注意力权重
attn = (k @ q.transpose(-2, -1)) # A = K^T*Q
attn = attn * self.rescale
attn = attn.softmax(dim=-1)
# 计算注意力输出
x = attn @ v # b,heads,d,hw
x = x.permute(0, 3, 1, 2) # Transpose
# 重塑输出以适应原始维度
x = x.reshape(b, h * w, self.num_heads * self.dim_head)
out_c = self.proj(x).view(b, h, w, c)
# 将位置编码应用到值(V)上并与自注意力计算结果相加
out_p = self.pos_emb(v_inp.reshape(b,h,w,c).permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
out = out_c + out_p
return out
2.Mask-guided Mechanism(MM)
在直接使用 Transformer 进行 HSI 恢复时,传统的 Transformer 模型可能 会注意到一些信息较少的空间区域,其中包含质量低的 HSI 表示。为了解决这 个问题,提出了一个称为 Mask-guidedMechanism(MM) 的机制
- 设计理念: MM利用CASSI系统中的物理掩模信息,指导S-MSA关注高保真光谱表示的空间区域。这种方法能够有效避免模型关注低保真区域,从而提高重建质量
- 操作流程:
- 首先,对掩模进行位移处理,以匹配CASSI系统中的光谱信息。
- 经过一系列卷积操作后,MM生成一个掩模注意力图,该图与输入特征进行结合。
- 在S-MSA中,掩模注意力图用于加权V,从而引导模型关注更有信息量的区域
# Mask Guided Mechanism模块 问题二 MM
class MaskGuidedMechanism(nn.Module):
def __init__(
self, n_feat):
super(MaskGuidedMechanism, self).__init__()
self.conv1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=True)
self.conv2 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=True)
self.depth_conv = nn.Conv2d(n_feat, n_feat, kernel_size=5, padding=2, bias=True, groups=n_feat) # 使用 n_feat 组
# 深度卷积层,使用 5x5 的卷积核,并设置 groups=n_feat,表示每个输入通道都独立卷积
# 深度卷积可以帮助提取更复杂的特征,同时保持每个通道的特征独立性
def forward(self, mask_shift):
# x: b,c,h,w
[bs, nC, row, col] = mask_shift.shape
mask_shift = self.conv1(mask_shift)
attn_map = torch.sigmoid(self.depth_conv(self.conv2(mask_shift))) # 深度卷积层
res = mask_shift * attn_map # 注意力加权
mask_shift = res + mask_shift # 加权后结果与原始 mask_shift 相加
mask_emb = shift_back(mask_shift) # 恢复某之前的状态或调整尺寸
return mask_emb
三、代码实现
1.MSAB 模块与 FFN
# MSAB 模块
class MSAB(nn.Module):
def __init__(
self,
dim,
dim_head=64,
heads=8,
num_blocks=2,
):
super().__init__()
# 初始化多个块在一个 ModuleList 中
self.blocks = nn.ModuleList([])
for _ in range(num_blocks): # # 每个块包括一个多头自注意力模块和一个预归一化的前馈神经网络模块
self.blocks.append(nn.ModuleList([
MS_MSA(dim=dim, dim_head=dim_head, heads=heads), # 多头自注意力模块
PreNorm(dim, FeedForward(dim=dim)) # 预归一化的前馈神经网络模块
]))
def forward(self, x, mask):
"""
x: [b,c,h,w]
return out: [b,c,h,w]
"""
x = x.permute(0, 2, 3, 1) # 将张量布局改为 [b,h,w,c],以便用于注意力模块
for (attn, ff) in self.blocks:
# # 应用多头自注意力并添加残差连接
x = attn(x, mask=mask.permute(0, 2, 3, 1)) + x
# 应用前馈神经网络并添加残差连接
x = ff(x) + x
out = x.permute(0, 3, 1, 2) # 将张量布局恢复为 [b,c,h,w] 作为输出
return out
#FFN
class FeedForward(nn.Module):
def __init__(self, dim, mult=4):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim, dim * mult, 1, 1, bias=False),
GELU(),
nn.Conv2d(dim * mult, dim * mult, 3, 1, 1, bias=False, groups=dim * mult),
GELU(),
nn.Conv2d(dim * mult, dim, 1, 1, bias=False),
)
def forward(self, x):
"""
x: [b,h,w,c]
return out: [b,h,w,c]
"""
"""
permute(0, 3, 1, 2) 表示将原始张量 x 的维度重新排列为:
第0维保持不变,
将第3维移到第1个位置,
将原来的第1维移到第2个位置,
将原来的第2维移到第3个位置
"""
out = self.net(x.permute(0, 3, 1, 2))
return out.permute(0, 2, 3, 1)
2.MST框架
# MST框架
class MST(nn.Module): # dim (维度大小) stage(阶段数) num_blocks(每个阶段的块数)
def __init__(self, dim=28, stage=3, num_blocks=[2,2,2]):
super(MST, self).__init__()
self.dim = dim
self.stage = stage
# Input projection 输入通道 输出通道
self.embedding = nn.Conv2d(28, self.dim, 3, 1, 1, bias=False)
# Encoder 编码器
self.encoder_layers = nn.ModuleList([])
dim_stage = dim
for i in range(stage):
self.encoder_layers.append(nn.ModuleList([
MSAB(
dim=dim_stage, num_blocks=num_blocks[i], dim_head=dim, heads=dim_stage // dim),
# 对输入特征进行下采样
nn.Conv2d(dim_stage, dim_stage * 2, 4, 2, 1, bias=False),
# 类似于上述的下采样卷积层,不同之处在于它处理的是掩码(mask)数据,同样的卷积设置用于降低掩码的维度
nn.Conv2d(dim_stage, dim_stage * 2, 4, 2, 1, bias=False)
]))
# dim_stage 初始值为 dim,随着每个阶段的迭代而翻倍,确保每个阶段处理的特征维度都是前一个阶段的两倍
dim_stage *= 2
# Bottleneck 瓶颈
# 使用一个多头自注意力模块(MSAB)来处理编码器的最终输出
self.bottleneck = MSAB(
dim=dim_stage, dim_head=dim, heads=dim_stage // dim, num_blocks=num_blocks[-1])
# Decoder 解码器
self.decoder_layers = nn.ModuleList([])
for i in range(stage):
self.decoder_layers.append(nn.ModuleList([
# 上采样操作
nn.ConvTranspose2d(dim_stage, dim_stage // 2, stride=2, kernel_size=2, padding=0, output_padding=0),
# 普通的卷积层,用于进一步处理解码器的输出
nn.Conv2d(dim_stage, dim_stage // 2, 1, 1, bias=False),
MSAB(
dim=dim_stage // 2, num_blocks=num_blocks[stage - 1 - i], dim_head=dim,
heads=(dim_stage // 2) // dim),
]))
# 在每次循环迭代中,dim_stage被减半,这是因为解码器每一层通常会减少特征图的深度和宽高尺寸
dim_stage //= 2
# Output projection 最终的输出投影层,将模型输出的特征图映射回28个通道
self.mapping = nn.Conv2d(self.dim, 28, 3, 1, 1, bias=False)
# activation function 使用LeakyReLU作为激活函数,负斜率为0.1
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
def forward(self, x, mask=None):
"""
x: [b,c,h,w]
return out:[b,c,h,w]
"""
if mask == None: # None 是一个特殊的空值对象,表示没有值
mask = torch.zeros((1,28,256,310)).cuda() # 批量大小为1,具有28个通道,高度为256,宽度为310 .cuda()在cpu上运行
# Embedding
fea = self.lrelu(self.embedding(x)) # 将输入x经过LeakyReLU激活后,通过输入投影层self.embedding进行特征提取
# Encoder
# 特征fea 掩码mask
fea_encoder = []
masks = []
for (MSAB, FeaDownSample, MaskDownSample) in self.encoder_layers:
fea = MSAB(fea, mask)
masks.append(mask)
fea_encoder.append(fea)
fea = FeaDownSample(fea)
mask = MaskDownSample(mask)
# Bottleneck
fea = self.bottleneck(fea, mask)
# Decoder
for i, (FeaUpSample, Fution, LeWinBlcok) in enumerate(self.decoder_layers):
fea = FeaUpSample(fea)
fea = Fution(torch.cat([fea, fea_encoder[self.stage-1-i]], dim=1))
mask = masks[self.stage - 1 - i]
fea = LeWinBlcok(fea, mask)
# Fution 融合模块,将上采样后的特征与编码器对应层的特征进行融合
# LeWinBlcok: 局部窗块,使用局部窗自注意力机制,捕捉特征之间的局部依赖关系
# Mapping
out = self.mapping(fea) + x
# 最终通过输出投影层self.mapping将特征映射为模型输出,并加上原始输入x,以产生最终的输出out
return out