一、研究背景
现有高光谱图像重建方法的不足:
1.模型类方法
:依赖手工图像先验,需要手动调整参数,重建速度慢,且表示能力和泛化
能力有限。
2.即插即用算法
:将预训练的去噪网络插入传统模型方法中,但预训练网络固定不重新训
练,性能受限。
3.端到端算法
:通常采用卷积神经网络(
CNN
),学习从测量到期望高光谱图像的端到端
映射函数,但忽略了
CASSI
系统的工作原理,缺乏理论证明、可解释性和灵活性。
4.深度展开方法
:采用多阶段网络将测量映射到高光谱立方体,但现有方法存在不估计
CASSI
退化模式、主要基于
CNN
在捕获非局部自相似性和长程依赖方面有限等问题。
二、方法
1. Half-Shuffle Transformer
整体结构:采用三级
U
形结构,由
Half-Shuffle Attention Block (HSAB)
构建。首先,用
一个
3*3
卷积将重塑后的
X
k
与拉伸后的β
k
映射为特征
X
0
。然后,
X
0
通过编码器、瓶颈和解
码器被嵌入为深度特征
X
d
,编码器或解码器的每个级别包含一个
HSAB
和一个调整大小的
模块。最后,一个
3*3
卷积
作用于
X
d
生成
残差图像
R
,输出去噪图像
Z
k
为
Xk
与重塑后的
R
之和。
HSAB
组成
:
HSAB
由两个层归一化(
LN
)、一个
Half-Shuffle Multi-head Self-Attention
(HS-MSA)
模块和一个前馈网络(
FFN
)组成。下采样和上采样模块分别是步长为
4*4
的卷积
和
2*2
的反卷积
三、代码实现
1.张量初始化
# 将一个张量初始化为符合截断正态分布的值,值的范围被限制在 [a, b] 之间
def _no_grad_trunc_normal_(tensor, mean, std, a, b): # 张量 均值 标准差 正态分部的上下界
def norm_cdf(x): # 计算正态分布的累积分布函数(CDF),使用误差函数(erf)
return (1. + math.erf(x / math.sqrt(2.))) / 2.
# 如果均值超出 [a, b] 范围的 2 个标准差之外,发出警告,表明分布可能不准确。
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2) # 等于2 使得警告信息能够更准确地指示出问题发生的代码位置
# 初始化
with torch.no_grad():
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std) # l 和 u: 计算截断点 a 和 b 对应的 CDF 值
tensor.uniform_(2 * l - 1, 2 * u - 1) # 将张量初始化为在 2 * l - 1 和 2 * u - 1 之间均匀分布的值。
# 使用逆误差函数(erfinv)将这些均匀值转换为正态分布,缩放为 std * sqrt(2) 并加上 mean
tensor.erfinv_()
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# 确保值在 [a, b] 范围内
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
# type: (Tensor, float, float, float, float) -> Tensor
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
"""
希望张量中的每个元素都符合一个特定的正态分布,但只在一个指定的范围 [a, b] 内。这种分布称为“截断正态分布”,它是正态分布的一种变体,
其中的值被限制在某个区间内。这个函数确保初始化的张量的元素遵循这种分布,从而使得模型参数的初始值在合理范围内,有助于模型的训练和收敛
"""
class GELU(nn.Module):
def forward(self, x):
return F.gelu(x)
2.前馈神经网络模块 FFN
class FeedForward(nn.Module):
def __init__(self, dim, mult=4): # 通过设置 mult,你可以调整隐藏层的宽度
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim, dim * mult, 1, 1, bias=False),
GELU(),
nn.Conv2d(dim * mult, dim * mult, 3, 1, 1, bias=False, groups=dim * mult),
GELU(),
nn.Conv2d(dim * mult, dim, 1, 1, bias=False),
)
def forward(self, x):
"""
x: [b,h,w,c]
return out: [b,h,w,c]
"""
out = self.net(x.permute(0, 3, 1, 2))
return out.permute(0, 2, 3, 1)
3.HS-MSA
# Half-Shuffle Multi-head Self-Attention (HS-MSA)
class HS_MSA(nn.Module):
def __init__(
self,
dim, # 特征维度
window_size=(8, 8),# 窗口大小
dim_head=28, # 每个注意力头的维度
heads=8, # 注意力头的数量
only_local_branch=False # 一个布尔值,指示是否仅使用局部分支
):
super().__init__()
self.dim = dim
self.heads = heads
self.scale = dim_head ** -0.5 # 缩放因子
self.window_size = window_size
self.only_local_branch = only_local_branch
# position embedding
if only_local_branch:
seq_l = window_size[0] * window_size[1] # 窗口内的序列长度 窗口大小为 8x8,那么 seq_l 就是 64。
# nn.Parameter 使得张量成为模型的一部分,并且能够被自动地包含在优化过程中
self.pos_emb = nn.Parameter(torch.Tensor(1, heads, seq_l, seq_l)) # 创建一个形状为 (1, heads, seq_l, seq_l) 的张量,作为位置嵌入参数
trunc_normal_(self.pos_emb)# 初始化为正态分布的随机数(有范围)
else:
seq_l1 = window_size[0] * window_size[1]
# 创建一个形状为 (1, 1, heads//2, seq_l1, seq_l1) 的张量
# 1 是批量维度。1 表示只有一个局部位置嵌入分支。
# heads//2 表示注意力头的数量除以2(因为这里的分支数为2)。
self.pos_emb1 = nn.Parameter(torch.Tensor(1, 1, heads//2, seq_l1, seq_l1))
# 将整体特征图的尺寸 256(高度)和 320(宽度)按注意力头的数量 self.heads 进行划分。
# 这样,每个注意力头处理的特征图就会有 h 行和 w 列的尺寸 可以确保每个头处理的特征图具有适当的空间分辨率
h,w = 256//self.heads,320//self.heads
seq_l2 = h*w//seq_l1
self.pos_emb2 = nn.Parameter(torch.Tensor(1, 1, heads//2, seq_l2, seq_l2))
trunc_normal_(self.pos_emb1)
trunc_normal_(self.pos_emb2)
inner_dim = dim_head * heads
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim) # 用于将多头注意力的输出转换回输入特征维度
def forward(self, x):
"""
x: [b,h,w,c]
return out: [b,h,w,c]
"""
b, h, w, c = x.shape
w_size = self.window_size
assert h % w_size[0] == 0 and w % w_size[1] == 0, 'fmap dimensions must be divisible by the window size'
# 分支
if self.only_local_branch:
# 假设 w_size = (4, 4),那么 b0 和 b1 都是 4。在这种情况下,
# 输入张量 x 的形状可能是 [b, h * 4, w * 4, c],
# 而 x_inp 的形状将会是 [b * h * w, 16, c],其中 16 是 4 * 4 的结果
x_inp = rearrange(x, 'b (h b0) (w b1) c -> (b h w) (b0 b1) c', b0=w_size[0], b1=w_size[1])
q = self.to_q(x_inp)
k, v = self.to_kv(x_inp).chunk(2, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), (q, k, v))
q *= self.scale
sim = einsum('b h i d, b h j d -> b h i j', q, k) # 张量乘积和求和
sim = sim + self.pos_emb # 加入嵌入位置
attn = sim.softmax(dim=-1) # -1,最后一个维度 j 上应用 softmax 函数
out = einsum('b h i j, b h j d -> b h i d', attn, v) # 加权求和
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out) # 线性变换, 将 out 张量的最后一个维度映射到最终的输出维度
out = rearrange(out, '(b h w) (b0 b1) c -> b (h b0) (w b1) c', h=h // w_size[0], w=w // w_size[1],
b0=w_size[0])
else:
q = self.to_q(x)
k, v = self.to_kv(x).chunk(2, dim=-1)
q1, q2 = q[:,:,:,:c//2], q[:,:,:,c//2:]
k1, k2 = k[:,:,:,:c//2], k[:,:,:,c//2:]
v1, v2 = v[:,:,:,:c//2], v[:,:,:,c//2:]
# local branch
q1, k1, v1 = map(lambda t: rearrange(t, 'b (h b0) (w b1) c -> b (h w) (b0 b1) c',
b0=w_size[0], b1=w_size[1]), (q1, k1, v1))
q1, k1, v1 = map(lambda t: rearrange(t, 'b n mm (h d) -> b n h mm d', h=self.heads//2), (q1, k1, v1))
q1 *= self.scale
sim1 = einsum('b n h i d, b n h j d -> b n h i j', q1, k1) # 点积
sim1 = sim1 + self.pos_emb1
attn1 = sim1.softmax(dim=-1)
out1 = einsum('b n h i j, b n h j d -> b n h i d', attn1, v1)# 加权求和
out1 = rearrange(out1, 'b n h mm d -> b n mm (h d)')
# non-local branch
q2, k2, v2 = map(lambda t: rearrange(t, 'b (h b0) (w b1) c -> b (h w) (b0 b1) c',
b0=w_size[0], b1=w_size[1]), (q2, k2, v2))
q2, k2, v2 = map(lambda t: t.permute(0, 2, 1, 3), (q2.clone(), k2.clone(), v2.clone()))
q2, k2, v2 = map(lambda t: rearrange(t, 'b n mm (h d) -> b n h mm d', h=self.heads//2), (q2, k2, v2))
q2 *= self.scale
sim2 = einsum('b n h i d, b n h j d -> b n h i j', q2, k2)
sim2 = sim2 + self.pos_emb2
attn2 = sim2.softmax(dim=-1)
out2 = einsum('b n h i j, b n h j d -> b n h i d', attn2, v2)
out2 = rearrange(out2, 'b n h mm d -> b n mm (h d)')
out2 = out2.permute(0, 2, 1, 3)
out = torch.cat([out1,out2],dim=-1).contiguous()
out = self.to_out(out)
out = rearrange(out, 'b (h w) (b0 b1) c -> b (h b0) (w b1) c', h=h // w_size[0], w=w // w_size[1],
b0=w_size[0])
return out
4.HSAB
# Half-Shuffle Attention Block (HSAB)
class HSAB(nn.Module):
def __init__(
self,
dim,
window_size=(8, 8),
dim_head=64,
heads=8,
num_blocks=2,
):
super().__init__()
self.blocks = nn.ModuleList([])
for _ in range(num_blocks):
self.blocks.append(nn.ModuleList([
PreNorm(dim, HS_MSA(dim=dim, window_size=window_size, dim_head=dim_head, heads=heads, only_local_branch=(heads==1))),
PreNorm(dim, FeedForward(dim=dim))
]))
def forward(self, x):
"""
x: [b,c,h,w]
return out: [b,c,h,w]
"""
x = x.permute(0, 2, 3, 1)
for (attn, ff) in self.blocks:
x = attn(x) + x
x = ff(x) + x
out = x.permute(0, 3, 1, 2)
return out
5.HST框架
# 框架
class HST(nn.Module):
def __init__(self, in_dim=28, out_dim=28, dim=28, num_blocks=[1,1,1]):
super(HST, self).__init__()
self.dim = dim
self.scales = len(num_blocks)
# Input projection
self.embedding = nn.Conv2d(in_dim, self.dim, 3, 1, 1, bias=False)
# Encoder
self.encoder_layers = nn.ModuleList([])
dim_scale = dim
for i in range(self.scales-1): # 循环的次数取决于 self.scales
self.encoder_layers.append(nn.ModuleList([
HSAB(dim=dim_scale, num_blocks=num_blocks[i], dim_head=dim, heads=dim_scale // dim),
# 下采样
nn.Conv2d(dim_scale, dim_scale * 2, 4, 2, 1, bias=False),
]))
dim_scale *= 2 # 每次循环后,将 dim_scale 乘以 2,为下一层做准备
# Bottleneck
self.bottleneck = HSAB(dim=dim_scale, dim_head=dim, heads=dim_scale // dim, num_blocks=num_blocks[-1])
# Decoder
self.decoder_layers = nn.ModuleList([])
for i in range(self.scales-1):
self.decoder_layers.append(nn.ModuleList([
# 上采样
nn.ConvTranspose2d(dim_scale, dim_scale // 2, stride=2, kernel_size=2, padding=0, output_padding=0),
# 1*1 的卷积 连接
nn.Conv2d(dim_scale, dim_scale // 2, 1, 1, bias=False),
HSAB(dim=dim_scale // 2, num_blocks=num_blocks[self.scales - 2 - i], dim_head=dim,
heads=(dim_scale // 2) // dim),
]))
dim_scale //= 2
# Output projection
self.mapping = nn.Conv2d(self.dim, out_dim, 3, 1, 1, bias=False)
#### activation function
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
"""
x: [b,c,h,w]
return out:[b,c,h,w]
"""
b, c, h_inp, w_inp = x.shape
hb, wb = 16, 16 # 希望填充后的高度和宽度都是 16 的倍数
# pad_h 和 pad_w 分别是计算需要填充的高度和宽度。这个计算确保填充后的尺寸是 16 的倍数。
pad_h = (hb - h_inp % hb) % hb
pad_w = (wb - w_inp % wb) % wb
# F.pad 函数对张量 x 进行填充
x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')
"""
[0, pad_w, 0, pad_h] 表示在宽度方向上在右边填充 pad_w,在高度方向上在底部填充 pad_h。
填充模式为 'reflect',即反射填充,即边缘像素值反射填充到新增的区域。
"""
# Embedding
fea = self.embedding(x)
x = x[:,:28,:,:]
# Encoder
fea_encoder = []
for (HSAB, FeaDownSample) in self.encoder_layers:
fea = HSAB(fea)
fea_encoder.append(fea)
fea = FeaDownSample(fea)
# Bottleneck
fea = self.bottleneck(fea)
# Decoder
for i, (FeaUpSample, Fution, HSAB) in enumerate(self.decoder_layers):
fea = FeaUpSample(fea)
fea = Fution(torch.cat([fea, fea_encoder[self.scales-2-i]], dim=1))
fea = HSAB(fea)
# Mapping
out = self.mapping(fea) + x
return out[:, :, :h_inp, :w_inp]
填充高度和宽度:pad_h = (hb - h_inp % hb) % hb ,pad_w = (wb - w_inp % wb) % wb
6.定义一些基本操作:
# 线性变换
def A(x,Phi):
temp = x*Phi
y = torch.sum(temp,1)
return y
def At(y,Phi):
temp = torch.unsqueeze(y, 1).repeat(1,Phi.shape[1],1,1)
x = temp*Phi
return x
# 对输入的三维张量 inputs 进行移位操作
# 对输入张量的每个通道在列方向上进行独立的移位操作
def shift_3d(inputs,step=2):
[bs, nC, row, col] = inputs.shape
for i in range(nC): # 当 dims=2 时,表示在第 3 个维度(从 0 开始计数)上进行滚动,也就是在列的方向上进行移位 roll滚动,位移
inputs[:,i,:,:] = torch.roll(inputs[:,i,:,:], shifts=step*i, dims=2)
return inputs
# 对输入张量 inputs 的每个通道进行反向滚动(移位),将其恢复到原始状态。它实际上是 shift_3d 函数的逆操作。
def shift_back_3d(inputs,step=2):
[bs, nC, row, col] = inputs.shape
for i in range(nC):
inputs[:,i,:,:] = torch.roll(inputs[:,i,:,:], shifts=(-1)*step*i, dims=2)
return inputs
7.神经网络模块
# 神经网络模块
class HyPaNet(nn.Module):
def __init__(self, in_nc=29, out_nc=8, channel=64):
super(HyPaNet, self).__init__()
self.fution = nn.Conv2d(in_nc, channel, 1, 1, 0, bias=True)
self.down_sample = nn.Conv2d(channel, channel, 3, 2, 1, bias=True)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.mlp = nn.Sequential(
nn.Conv2d(channel, channel, 1, padding=0, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(channel, channel, 1, padding=0, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(channel, out_nc, 1, padding=0, bias=True),
nn.Softplus())
self.relu = nn.ReLU(inplace=True)
self.out_nc = out_nc
def forward(self, x):
x = self.down_sample(self.relu(self.fution(x)))
x = self.avg_pool(x)
x = self.mlp(x) + 1e-6 # 将输出分为两部分:前半部分和后半部分,根据 out_nc 分成两半。
return x[:,:self.out_nc//2,:,:], x[:,self.out_nc//2:,:,:
8. 光谱压缩成像的模型
# 光谱压缩成像的模型
class DAUHST(nn.Module):
def __init__(self, num_iterations=1):
super(DAUHST, self).__init__()
# 被用作参数估计器,输入通道为28,输出通道为 num_iterations * 2
self.para_estimator = HyPaNet(in_nc=28, out_nc=num_iterations*2)
# 将通道数从56减少到28
self.fution = nn.Conv2d(56, 28, 1, padding=0, bias=True)
# 迭代次数
self.num_iterations = num_iterations
self.denoisers = nn.ModuleList([])
for _ in range(num_iterations):
self.denoisers.append(
HST(in_dim=29, out_dim=28, dim=28, num_blocks=[1,1,1]),
)
# 根据输入的压缩测量y和感知矩阵Phi进行初始化操作
def initial(self, y, Phi):
"""
:param y: [b,256,310] 表示批量图像数据
:param Phi: [b,28,256,310] 通常表示感知矩阵(也就是变换矩阵)
:return: temp: [b,28,256,310]; alpha: [b, num_iterations]; beta: [b, num_iterations]
"""
nC, step = 28, 2
y = y / nC * 2 # 将 y 的值归一化到 [0, 2] 的范围
bs,row,col = y.shape
# y_shift 是一个新的张量,其形状为 [b, 28, 256, 310],初始化为全零
y_shift = torch.zeros(bs, nC, row, col).cuda().float()
# 通过循环,将 y 的不同片段赋值给 y_shift。
# 每个片段的起始位置由 step * i 决定,片段的宽度会减去 (nC - 1) * step 以确保对齐。
for i in range(nC):
y_shift[:, i, :, step * i:step * i + col - (nC - 1) * step] = y[:, :, step * i:step * i + col - (nC - 1) * step]
# 将 y_shift 和 Phi 沿通道维度拼接,然后通过 fution 卷积层得到 z
# 这一步是通过卷积操作将两个输入融合成新的特征图 z
z = self.fution(torch.cat([y_shift, Phi], dim=1))
# 再次拼接 y_shift 和 Phi,然后通过 para_estimator 网络模块计算 alpha 和 beta
alpha, beta = self.para_estimator(self.fution(torch.cat([y_shift, Phi], dim=1)))
return z, alpha, beta # 这些参数用于后续的迭代处理
def forward(self, y, input_mask=None):
""" P 每个阶段的线性投影网络 D 去噪网络
:param y: [b,256,310]
:param Phi: [b,28,256,310]
:param Phi_PhiT: [b,256,310]
:return: z_crop: [b,28,256,256]
"""
Phi, Phi_s = input_mask
z, alphas, betas = self.initial(y, Phi)
for i in range(self.num_iterations):
# 参数提取
alpha, beta = alphas[:,i,:,:], betas[:,i:i+1,:,:]
# 线性变换
Phi_z = A(z, Phi)
# div 逐元素除法的函数--> 计算 (y - Phi_z) / (alpha + Phi_s)
x = z + At(torch.div(y-Phi_z,alpha+Phi_s), Phi)
x = shift_back_3d(x)
# 将 beta 扩展到与 x 的形状一致。
beta_repeat = beta.repeat(1,1,x.shape[2], x.shape[3])
# 将 x 和扩展的 beta 拼接,并通过当前的去噪网络 self.denoisers[i] 处理
# HST
z = self.denoisers[i](torch.cat([x, beta_repeat],dim=1))
# 在最后一次迭代之前,通过 shift_3d 调整 z 的形状
if i<self.num_iterations-1:
z = shift_3d(z)
return z[:, :, :, 0:256]