Degradation-Aware Unfolding Half-Shuffle Transformer for Spectral Compressive Imaging

一、研究背景

现有高光谱图像重建方法的不足:
1.模型类方法 :依赖手工图像先验,需要手动调整参数,重建速度慢,且表示能力和泛化
能力有限。
2.即插即用算法 :将预训练的去噪网络插入传统模型方法中,但预训练网络固定不重新训
练,性能受限。
3.端到端算法 :通常采用卷积神经网络( CNN ),学习从测量到期望高光谱图像的端到端
映射函数,但忽略了 CASSI 系统的工作原理,缺乏理论证明、可解释性和灵活性。
4.深度展开方法 :采用多阶段网络将测量映射到高光谱立方体,但现有方法存在不估计
CASSI 退化模式、主要基于 CNN 在捕获非局部自相似性和长程依赖方面有限等问题。

二、方法

1. Half-Shuffle Transformer
整体结构:采用三级 U 形结构,由 Half-Shuffle Attention Block (HSAB) 构建。首先,用
一个 3*3 卷积将重塑后的 X k 与拉伸后的β k 映射为特征 X 0 。然后, X 0 通过编码器、瓶颈和解
码器被嵌入为深度特征 X d ,编码器或解码器的每个级别包含一个 HSAB 和一个调整大小的
模块。最后,一个 3*3 卷积 作用于 X d 生成 残差图像 R ,输出去噪图像 Z k Xk 与重塑后的 R
之和。
HSAB 组成 HSAB 由两个层归一化( LN )、一个 Half-Shuffle Multi-head Self-Attention
(HS-MSA) 模块和一个前馈网络( FFN )组成。下采样和上采样模块分别是步长为 4*4 的卷积
2*2 的反卷积

三、代码实现 

1.张量初始化

# 将一个张量初始化为符合截断正态分布的值,值的范围被限制在 [a, b] 之间
def _no_grad_trunc_normal_(tensor, mean, std, a, b): # 张量 均值 标准差 正态分部的上下界
    def norm_cdf(x): # 计算正态分布的累积分布函数(CDF),使用误差函数(erf)
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    # 如果均值超出 [a, b] 范围的 2 个标准差之外,发出警告,表明分布可能不准确。
    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2) # 等于2 使得警告信息能够更准确地指示出问题发生的代码位置
    # 初始化
    with torch.no_grad():
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std) # l 和 u: 计算截断点 a 和 b 对应的 CDF 值
        tensor.uniform_(2 * l - 1, 2 * u - 1) # 将张量初始化为在 2 * l - 1 和 2 * u - 1 之间均匀分布的值。
        # 使用逆误差函数(erfinv)将这些均匀值转换为正态分布,缩放为 std * sqrt(2) 并加上 mean
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        # 确保值在 [a, b] 范围内
        tensor.clamp_(min=a, max=b)
        return tensor

def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
"""
希望张量中的每个元素都符合一个特定的正态分布,但只在一个指定的范围 [a, b] 内。这种分布称为“截断正态分布”,它是正态分布的一种变体,
其中的值被限制在某个区间内。这个函数确保初始化的张量的元素遵循这种分布,从而使得模型参数的初始值在合理范围内,有助于模型的训练和收敛
"""

class GELU(nn.Module):
    def forward(self, x):
        return F.gelu(x)

2.前馈神经网络模块 FFN

class FeedForward(nn.Module):
    def __init__(self, dim, mult=4): # 通过设置 mult,你可以调整隐藏层的宽度
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim * mult, 1, 1, bias=False),
            GELU(),
            nn.Conv2d(dim * mult, dim * mult, 3, 1, 1, bias=False, groups=dim * mult),
            GELU(),
            nn.Conv2d(dim * mult, dim, 1, 1, bias=False),
        )

    def forward(self, x):
        """
        x: [b,h,w,c]
        return out: [b,h,w,c]
        """
        out = self.net(x.permute(0, 3, 1, 2))
        return out.permute(0, 2, 3, 1)

 3.HS-MSA

# Half-Shuffle Multi-head Self-Attention (HS-MSA)
class HS_MSA(nn.Module):
    def __init__(
            self,
            dim, # 特征维度
            window_size=(8, 8),# 窗口大小
            dim_head=28, # 每个注意力头的维度
            heads=8, # 注意力头的数量
            only_local_branch=False # 一个布尔值,指示是否仅使用局部分支
    ):
        super().__init__()

        self.dim = dim
        self.heads = heads
        self.scale = dim_head ** -0.5 # 缩放因子
        self.window_size = window_size
        self.only_local_branch = only_local_branch

        # position embedding
        if only_local_branch:
            seq_l = window_size[0] * window_size[1] # 窗口内的序列长度 窗口大小为 8x8,那么 seq_l 就是 64。
            # nn.Parameter 使得张量成为模型的一部分,并且能够被自动地包含在优化过程中
            self.pos_emb = nn.Parameter(torch.Tensor(1, heads, seq_l, seq_l)) # 创建一个形状为 (1, heads, seq_l, seq_l) 的张量,作为位置嵌入参数
            trunc_normal_(self.pos_emb)# 初始化为正态分布的随机数(有范围)
        else:
            seq_l1 = window_size[0] * window_size[1]
            # 创建一个形状为 (1, 1, heads//2, seq_l1, seq_l1) 的张量
            # 1 是批量维度。1 表示只有一个局部位置嵌入分支。
            # heads//2 表示注意力头的数量除以2(因为这里的分支数为2)。
            self.pos_emb1 = nn.Parameter(torch.Tensor(1, 1, heads//2, seq_l1, seq_l1))

            # 将整体特征图的尺寸 256(高度)和 320(宽度)按注意力头的数量 self.heads 进行划分。
            # 这样,每个注意力头处理的特征图就会有 h 行和 w 列的尺寸 可以确保每个头处理的特征图具有适当的空间分辨率
            h,w = 256//self.heads,320//self.heads

            seq_l2 = h*w//seq_l1
            self.pos_emb2 = nn.Parameter(torch.Tensor(1, 1, heads//2, seq_l2, seq_l2))

            trunc_normal_(self.pos_emb1)
            trunc_normal_(self.pos_emb2)

        inner_dim = dim_head * heads
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim) # 用于将多头注意力的输出转换回输入特征维度

    def forward(self, x):
        """
        x: [b,h,w,c]
        return out: [b,h,w,c]
        """
        b, h, w, c = x.shape
        w_size = self.window_size
        assert h % w_size[0] == 0 and w % w_size[1] == 0, 'fmap dimensions must be divisible by the window size'
        # 分支
        if self.only_local_branch:
        #  假设 w_size = (4, 4),那么 b0 和 b1 都是 4。在这种情况下,
        #  输入张量 x 的形状可能是 [b, h * 4, w * 4, c],
        #  而 x_inp 的形状将会是 [b * h * w, 16, c],其中 16 是 4 * 4 的结果
            x_inp = rearrange(x, 'b (h b0) (w b1) c -> (b h w) (b0 b1) c', b0=w_size[0], b1=w_size[1])
            q = self.to_q(x_inp)
            k, v = self.to_kv(x_inp).chunk(2, dim=-1)
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), (q, k, v))
            q *= self.scale

            sim = einsum('b h i d, b h j d -> b h i j', q, k) # 张量乘积和求和
            sim = sim + self.pos_emb # 加入嵌入位置
            attn = sim.softmax(dim=-1) # -1,最后一个维度 j 上应用 softmax 函数
            out = einsum('b h i j, b h j d -> b h i d', attn, v) # 加权求和
            out = rearrange(out, 'b h n d -> b n (h d)')
            out = self.to_out(out) # 线性变换, 将 out 张量的最后一个维度映射到最终的输出维度
            out = rearrange(out, '(b h w) (b0 b1) c -> b (h b0) (w b1) c', h=h // w_size[0], w=w // w_size[1],
                            b0=w_size[0])
        else:
            q = self.to_q(x)
            k, v = self.to_kv(x).chunk(2, dim=-1)
            q1, q2 = q[:,:,:,:c//2], q[:,:,:,c//2:]
            k1, k2 = k[:,:,:,:c//2], k[:,:,:,c//2:]
            v1, v2 = v[:,:,:,:c//2], v[:,:,:,c//2:]

            # local branch
            q1, k1, v1 = map(lambda t: rearrange(t, 'b (h b0) (w b1) c -> b (h w) (b0 b1) c',
                                              b0=w_size[0], b1=w_size[1]), (q1, k1, v1))
            q1, k1, v1 = map(lambda t: rearrange(t, 'b n mm (h d) -> b n h mm d', h=self.heads//2), (q1, k1, v1))
            q1 *= self.scale
            sim1 = einsum('b n h i d, b n h j d -> b n h i j', q1, k1) # 点积
            sim1 = sim1 + self.pos_emb1
            attn1 = sim1.softmax(dim=-1)
            out1 = einsum('b n h i j, b n h j d -> b n h i d', attn1, v1)# 加权求和
            out1 = rearrange(out1, 'b n h mm d -> b n mm (h d)')

            # non-local branch
            q2, k2, v2 = map(lambda t: rearrange(t, 'b (h b0) (w b1) c -> b (h w) (b0 b1) c',
                                                 b0=w_size[0], b1=w_size[1]), (q2, k2, v2))
            q2, k2, v2 = map(lambda t: t.permute(0, 2, 1, 3), (q2.clone(), k2.clone(), v2.clone()))
            q2, k2, v2 = map(lambda t: rearrange(t, 'b n mm (h d) -> b n h mm d', h=self.heads//2), (q2, k2, v2))
            q2 *= self.scale
            sim2 = einsum('b n h i d, b n h j d -> b n h i j', q2, k2)
            sim2 = sim2 + self.pos_emb2
            attn2 = sim2.softmax(dim=-1)
            out2 = einsum('b n h i j, b n h j d -> b n h i d', attn2, v2)
            out2 = rearrange(out2, 'b n h mm d -> b n mm (h d)')
            out2 = out2.permute(0, 2, 1, 3)

            out = torch.cat([out1,out2],dim=-1).contiguous()
            out = self.to_out(out)
            out = rearrange(out, 'b (h w) (b0 b1) c -> b (h b0) (w b1) c', h=h // w_size[0], w=w // w_size[1],
                            b0=w_size[0])

        return out

 4.HSAB

# Half-Shuffle Attention Block (HSAB)
class HSAB(nn.Module):
    def __init__(
            self,
            dim,
            window_size=(8, 8),
            dim_head=64,
            heads=8,
            num_blocks=2,
    ):
        super().__init__()
        self.blocks = nn.ModuleList([])
        for _ in range(num_blocks):
            self.blocks.append(nn.ModuleList([
                PreNorm(dim, HS_MSA(dim=dim, window_size=window_size, dim_head=dim_head, heads=heads, only_local_branch=(heads==1))),
                PreNorm(dim, FeedForward(dim=dim))
            ]))

    def forward(self, x):
        """
        x: [b,c,h,w]
        return out: [b,c,h,w]
        """
        x = x.permute(0, 2, 3, 1)
        for (attn, ff) in self.blocks:
            x = attn(x) + x
            x = ff(x) + x
        out = x.permute(0, 3, 1, 2)
        return out

5.HST框架

# 框架
class HST(nn.Module):
    def __init__(self, in_dim=28, out_dim=28, dim=28, num_blocks=[1,1,1]):
        super(HST, self).__init__()
        self.dim = dim
        self.scales = len(num_blocks)

        # Input projection
        self.embedding = nn.Conv2d(in_dim, self.dim, 3, 1, 1, bias=False)

        # Encoder
        self.encoder_layers = nn.ModuleList([])
        dim_scale = dim
        for i in range(self.scales-1): # 循环的次数取决于 self.scales
            self.encoder_layers.append(nn.ModuleList([
                HSAB(dim=dim_scale, num_blocks=num_blocks[i], dim_head=dim, heads=dim_scale // dim),
                # 下采样
                nn.Conv2d(dim_scale, dim_scale * 2, 4, 2, 1, bias=False),
            ]))
            dim_scale *= 2 # 每次循环后,将 dim_scale 乘以 2,为下一层做准备

        # Bottleneck
        self.bottleneck = HSAB(dim=dim_scale, dim_head=dim, heads=dim_scale // dim, num_blocks=num_blocks[-1])

        # Decoder
        self.decoder_layers = nn.ModuleList([])
        for i in range(self.scales-1):
            self.decoder_layers.append(nn.ModuleList([
                # 上采样
                nn.ConvTranspose2d(dim_scale, dim_scale // 2, stride=2, kernel_size=2, padding=0, output_padding=0),
                # 1*1 的卷积 连接
                nn.Conv2d(dim_scale, dim_scale // 2, 1, 1, bias=False),

                HSAB(dim=dim_scale // 2, num_blocks=num_blocks[self.scales - 2 - i], dim_head=dim,
                     heads=(dim_scale // 2) // dim),
            ]))
            dim_scale //= 2

        # Output projection
        self.mapping = nn.Conv2d(self.dim, out_dim, 3, 1, 1, bias=False)

        #### activation function
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """
        x: [b,c,h,w]
        return out:[b,c,h,w]
        """

        b, c, h_inp, w_inp = x.shape
        hb, wb = 16, 16 # 希望填充后的高度和宽度都是 16 的倍数
        # pad_h 和 pad_w 分别是计算需要填充的高度和宽度。这个计算确保填充后的尺寸是 16 的倍数。
        pad_h = (hb - h_inp % hb) % hb
        pad_w = (wb - w_inp % wb) % wb
        # F.pad 函数对张量 x 进行填充
        x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')
        """
        [0, pad_w, 0, pad_h] 表示在宽度方向上在右边填充 pad_w,在高度方向上在底部填充 pad_h。
        填充模式为 'reflect',即反射填充,即边缘像素值反射填充到新增的区域。
        """

        # Embedding
        fea = self.embedding(x)
        x = x[:,:28,:,:]

        # Encoder
        fea_encoder = []
        for (HSAB, FeaDownSample) in self.encoder_layers:
            fea = HSAB(fea)
            fea_encoder.append(fea)
            fea = FeaDownSample(fea)

        # Bottleneck
        fea = self.bottleneck(fea)

        # Decoder
        for i, (FeaUpSample, Fution, HSAB) in enumerate(self.decoder_layers):
            fea = FeaUpSample(fea)
            fea = Fution(torch.cat([fea, fea_encoder[self.scales-2-i]], dim=1))
            fea = HSAB(fea)

        # Mapping
        out = self.mapping(fea) + x
        return out[:, :, :h_inp, :w_inp]

填充高度和宽度:pad_h = (hb - h_inp % hb) % hb ,pad_w = (wb - w_inp % wb) % wb

6.定义一些基本操作:


# 线性变换
def A(x,Phi):
    temp = x*Phi
    y = torch.sum(temp,1)
    return y

def At(y,Phi):
    temp = torch.unsqueeze(y, 1).repeat(1,Phi.shape[1],1,1)
    x = temp*Phi
    return x

# 对输入的三维张量 inputs 进行移位操作
# 对输入张量的每个通道在列方向上进行独立的移位操作
def shift_3d(inputs,step=2):
    [bs, nC, row, col] = inputs.shape
    for i in range(nC): # 当 dims=2 时,表示在第 3 个维度(从 0 开始计数)上进行滚动,也就是在列的方向上进行移位  roll滚动,位移
        inputs[:,i,:,:] = torch.roll(inputs[:,i,:,:], shifts=step*i, dims=2)
    return inputs
# 对输入张量 inputs 的每个通道进行反向滚动(移位),将其恢复到原始状态。它实际上是 shift_3d 函数的逆操作。
def shift_back_3d(inputs,step=2):
    [bs, nC, row, col] = inputs.shape
    for i in range(nC):
        inputs[:,i,:,:] = torch.roll(inputs[:,i,:,:], shifts=(-1)*step*i, dims=2)
    return inputs

7.神经网络模块

# 神经网络模块
class HyPaNet(nn.Module):
    def __init__(self, in_nc=29, out_nc=8, channel=64):
        super(HyPaNet, self).__init__()
        self.fution = nn.Conv2d(in_nc, channel, 1, 1, 0, bias=True)
        self.down_sample = nn.Conv2d(channel, channel, 3, 2, 1, bias=True)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.mlp = nn.Sequential(
                nn.Conv2d(channel, channel, 1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel, channel, 1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel, out_nc, 1, padding=0, bias=True),
                nn.Softplus())
        self.relu = nn.ReLU(inplace=True)
        self.out_nc = out_nc

    def forward(self, x):
        x = self.down_sample(self.relu(self.fution(x)))
        x = self.avg_pool(x)
        x = self.mlp(x) + 1e-6 # 将输出分为两部分:前半部分和后半部分,根据 out_nc 分成两半。
        return x[:,:self.out_nc//2,:,:], x[:,self.out_nc//2:,:,:

8. 光谱压缩成像的模型

# 光谱压缩成像的模型
class DAUHST(nn.Module):

    def __init__(self, num_iterations=1):
        super(DAUHST, self).__init__()
        # 被用作参数估计器,输入通道为28,输出通道为 num_iterations * 2
        self.para_estimator = HyPaNet(in_nc=28, out_nc=num_iterations*2)
        # 将通道数从56减少到28
        self.fution = nn.Conv2d(56, 28, 1, padding=0, bias=True)
        # 迭代次数
        self.num_iterations = num_iterations

        self.denoisers = nn.ModuleList([])
        for _ in range(num_iterations):
            self.denoisers.append(
                HST(in_dim=29, out_dim=28, dim=28, num_blocks=[1,1,1]),
            )
    # 根据输入的压缩测量y和感知矩阵Phi进行初始化操作
    def initial(self, y, Phi):
        """
        :param y: [b,256,310] 表示批量图像数据
        :param Phi: [b,28,256,310] 通常表示感知矩阵(也就是变换矩阵)
        :return: temp: [b,28,256,310]; alpha: [b, num_iterations]; beta: [b, num_iterations]
        """
        nC, step = 28, 2
        y = y / nC * 2 # 将 y 的值归一化到 [0, 2] 的范围
        bs,row,col = y.shape
        # y_shift 是一个新的张量,其形状为 [b, 28, 256, 310],初始化为全零
        y_shift = torch.zeros(bs, nC, row, col).cuda().float()
        # 通过循环,将 y 的不同片段赋值给 y_shift。
        # 每个片段的起始位置由 step * i 决定,片段的宽度会减去 (nC - 1) * step 以确保对齐。
        for i in range(nC):
            y_shift[:, i, :, step * i:step * i + col - (nC - 1) * step] = y[:, :, step * i:step * i + col - (nC - 1) * step]
        # 将 y_shift 和 Phi 沿通道维度拼接,然后通过 fution 卷积层得到 z
        # 这一步是通过卷积操作将两个输入融合成新的特征图 z
        z = self.fution(torch.cat([y_shift, Phi], dim=1))
        # 再次拼接 y_shift 和 Phi,然后通过 para_estimator 网络模块计算 alpha 和 beta
        alpha, beta = self.para_estimator(self.fution(torch.cat([y_shift, Phi], dim=1)))
        return z, alpha, beta # 这些参数用于后续的迭代处理

    def forward(self, y, input_mask=None):
        """ P 每个阶段的线性投影网络  D  去噪网络
        :param y: [b,256,310]
        :param Phi: [b,28,256,310]
        :param Phi_PhiT: [b,256,310]
        :return: z_crop: [b,28,256,256]
        """
        Phi, Phi_s = input_mask
        z, alphas, betas = self.initial(y, Phi)

        for i in range(self.num_iterations):
            # 参数提取
            alpha, beta = alphas[:,i,:,:], betas[:,i:i+1,:,:]
            # 线性变换
            Phi_z = A(z, Phi)
            # div 逐元素除法的函数-->  计算 (y - Phi_z) / (alpha + Phi_s)
            x = z + At(torch.div(y-Phi_z,alpha+Phi_s), Phi)
            x = shift_back_3d(x)
            # 将 beta 扩展到与 x 的形状一致。
            beta_repeat = beta.repeat(1,1,x.shape[2], x.shape[3])
            # 将 x 和扩展的 beta 拼接,并通过当前的去噪网络 self.denoisers[i] 处理
            # HST
            z = self.denoisers[i](torch.cat([x, beta_repeat],dim=1))
            # 在最后一次迭代之前,通过 shift_3d 调整 z 的形状
            if i<self.num_iterations-1:
                z = shift_3d(z)
        return z[:, :, :, 0:256]

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值