HCS² -Net ,用于高光谱压缩快照重建的无监督空间-光谱网络

目录

一、文章解读

 二、框架的流程

三、代码解读

 1.空间-光谱注意力模块

 2.设计网络

2.1输入

2.2 瓶颈残差块 (BRB)

2.3完整结构

 2.4测试


 

一、文章解读

  1. 问题背景: 高光谱压缩成像利用压缩感知理论,通过编码孔径快照测量来获取高光谱数据,避免了时间扫描。其核心问题是如何利用压缩感知重建算法从压缩快照测量中重建原始高光谱图像

  2. 现有方法的不足: 由于不同光谱成像设备的光谱响应特性波长范围存在差异,现有方法往往难以捕捉复杂的光谱变化,也缺乏对新高光谱成像仪的适应能力。

  3. HCS² -Net 的优势: 为了解决这些问题,文章提出了一种无监督空间-光谱网络,仅从压缩快照测量中重建高光谱图像。该网络是一个条件生成模型,以快照测量为条件,利用空间-光谱注意力模块来捕捉高光谱图像的空间-光谱相关性。网络参数经过优化,确保网络输出能够根据成像模型与给定的快照测量紧密匹配,从而使网络能够适应不同的成像设置,提高网络的适用性。

  4. 网络架构: HCS² -Net 的网络架构主要包括以下几个部分:

    • 输入: 随机编码 Z 和快照测量 Y。
    • 特征提取: 多个 1×1 瓶颈残差块 (BRB) 用于提取特征。
    • 空间-光谱注意力模块: 用于捕捉高光谱图像的空间-光谱相关性。
    • 输出: 重建后的高光谱图像。
  5. 网络学习: 网络参数通过最小化重建图像与快照测量之间的误差来进行优化。

  6. 实验结果: 文章在多个数据集上进行了实验,结果表明 HCS² -Net 比现有方法取得了更好的重建效果。

 二、框架的流程

(BN: Batch Normalization,批归一化,ReLU: Rectified Linear Unit,修正线性单元 

1. 输入:

  • 随机编码 Z: 一个随机生成的矩阵,用于对高光谱图像进行编码。
  • 快照测量 Y: 通过编码孔径对原始高光谱图像进行压缩后的二维测量结果。

2. 特征提取:Bottleneck Residual Block (BRB)

  • 1×1 瓶颈残差块 (BRB): 输入数据首先经过多个 BRB 进行特征提取。BRB 是一种特殊的残差块,它使用 1×1 卷积来减少特征维度,并通过跳跃连接原始特征提取后的特征进行融合,从而提高网络的学习效率

3. 空间-光谱注意力模块:Spatial-Spectral Attention Module

  • 注意力机制: 该模块利用注意力机制来捕捉高光谱图像的空间-光谱相关性。它通过学习一个三维注意力矩阵,对每个特征点进行加权,从而突出重要的特征信息,抑制无关的信息。
  • 多尺度融合: 该模块还采用了多尺度融合策略,将不同尺度的特征信息进行融合,进一步提高重建精度。

4. 输出:

  •  1×1 卷积的作用是 调整输出通道数,使其与高光谱图像的波段数一致 ,
  • Sigmoid 激活函数的作用是将网络的输出限制在 0 到 1 之间
  • 重建后的高光谱图像: 经过特征提取注意力机制处理后,网络输出一个重建后的高光谱图像。

5. 网络学习:

  • 损失函数: 网络学习的目标是通过最小化重建图像与快照测量之间的误差来优化网络参数。
  • 优化算法: 文章采用 Adam 优化算法来更新网络参数。

总结:

HCS² -Net 通过将随机编码和快照测量作为输入,利用 BRB 和空间-光谱注意力模块提取特征,并通过最小化重建误差来学习网络参数,最终输出重建后的高光谱图像。

三、代码解读

 1.空间-光谱注意力模块


 

#  空间-光谱注意力模块
class spatial_att_new(nn.Module):

    def __init__(self, input_channel):
        super(spatial_att_new,self).__init__()
        self.bn = nn.BatchNorm2d(input_channel)
        self.relu = nn.ReLU(inplace=True)     # inplace=True 表示原地操作,直接在原始张量上进行修改 inplace=False 或不指定表示创建一个新的张量保存操作结果,保持原始张量不变。
        
       # 用于生成空间注意力权重图  1x1卷积层 步长1 包含偏置
        self.conv_sa = nn.Conv2d(input_channel, input_channel, kernel_size=1, stride=1, bias=True)
       
        # 用于调整特征通道数   步长1 包含偏置
        self.conv_1 = nn.Conv2d(input_channel, input_channel, kernel_size=1, stride=1, bias=True)
      
        # 用于提取局部特征   3x3卷积层 步长1 填充1 包括偏置
        self.conv_3 = nn.Conv2d(input_channel, input_channel, kernel_size=3, stride=1, padding=1, bias=True)
       
        # 用于对特征图进行下采样 使用 3×3 卷积核,步长为 2,填充为 1
        self.con_stride = nn.Conv2d(input_channel, input_channel, 3, stride=2, padding=1)
        
        # 用于将不同尺度的特征进行融合  3x3卷积层,填充为1,步幅为1,没有偏置
        self.con = nn.Conv2d(input_channel * 2, input_channel, 3, 1, 1)

    def forward(self, x):
        # 对输入特征图进行下采样和卷积操作,并使用 ReLU 和 Batch Normalization 进行激活和归一化
        x0 = self.relu(self.bn(self.conv_3(self.con_stride(x))))
        # 下采样后的特征图进行 1×1 卷积操作,并使用 ReLU 和 Batch Normalization 进行激活和归一化
        x1 = self.relu(self.bn(self.conv_1(x0)))     #(b,c,h/2,w/2)

        # 对下采样后的特征图再次进行下采样和卷积操作
        x2 = self.relu(self.bn(self.conv_3(self.con_stride(x1))))
        # 下采样后特征图进行 1×1 卷积操作
        x2 = self.relu(self.bn(self.conv_1(x2)))
        # 对下采样后的特征图进行上采样操作,使用双线性插值方法
        x2 = F.interpolate(x2, scale_factor=2, mode='bilinear', align_corners=False)


        x2 = self.conv_1(self.relu(self.bn(self.conv_1(self.conv_3(x2)))))
        # 对应文章中对上采样后的特征图进行 Sigmoid 激活,得到注意力权重
        x2_ = torch.sigmoid(x2)

        x1 = self.relu(self.bn(self.conv_3(x0)))
        x3 = x1 * x2_ * 2  # 注意力权重与下采样后的特征图进行 Hadamard 乘积,并乘以 2 进行缩放

        x3 = torch.cat([x3, x2], dim=1)  # 将注意力增强后的特征图与上采样后的特征图进行拼接
        x3 = self.relu(self.bn(self.con(x3)))  # 对拼接后的特征图进行卷积操作,并使用 ReLU 和 Batch Normalization 进行激活和归一化

        x3 = self.relu(self.bn(self.conv_1(x3)))
        # 对卷积后的特征图进行上采样操作,使用双线性插值方法
        x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)   #(b,c,h,w)

        # 中对上采样后的特征图进行 3×3 卷积操作
        x3 = self.conv_3(x3)
        # 对卷积后的特征图进行 1×1 卷积操作,并使用 ReLU 和 Batch Normalization 进行激活和归一化
        x3 = self.conv_1(self.relu(self.bn(self.conv_1(x3))))
        x4 = torch.sigmoid(x3)  # 对卷积后的特征图进行 Sigmoid 激活,得到注意力权重

        # 确保 x4 和 x3 的形状与输入 x 的形状一致
        if x4.shape == x.shape and x3.shape == x.shape:
            x4 = x4
            x3 = x3
        else:
            b, c, h, w = x.size()  # b (batch): 批次大小   c (channel): 通道数 h (height): 高度 w (width): 宽度
            # 将 x4 和 x3 的形状调整为与 x 一致
            x4 = x4[:, :, :h, :w]
            x3 = x3[:, :, :h, :w]

        out = x * x4 * 2

        out = torch.cat([x3, out], dim=1)  # 对应文章中将注意力增强后的特征图与原始特征图进行拼接
        out = self.relu(self.bn(self.con(out)))  # 对拼接后的特征图进行卷积操作,并使用 ReLU 和 Batch Normalization 进行激活和归一化

        return out

 2.设计网络

2.1输入

    # 从X中取出第1个通道到第64个通道之间的部分
        noise_input = x[:, 0:64, :, :]  # torch.Size([1, 32, 512, 512])
        # 从x中取出第64通道
        y_input = x[:, 64:65, :, :]     # torch.Size([1, 1, 512, 512])

        #随机编码 Z
        noise_input = self.conv_z(noise_input)

        #快照测量 y
        y = torch.sigmoid(self.conv_y(y_input))

         # 连接 作为输入
        x = torch.cat([noise_input, y], dim=1)

       
        x = self.act(self.bn(self.conv_3(x)))

2.2 瓶颈残差块 (BRB)

  # 第一个瓶颈残差块 (BRB)
        identity1 = x
        x = self.act(self.bn(self.conv3_1(x)))
        x = self.conv3_2(x)
        x1 = x + self.conv1(identity1) # +
        x1 = self.act(self.bn(x1))
        # 对结合的x1使用 ReLU 和 Batch Normalization 进行激活和归一化

2.3完整结构

class hslnet(nn.Module):

    def __init__(self, in_channels, out_channels, middle_channels):
        super(hslnet, self).__init__()
        self.conv_3 = nn.Conv2d(in_channels, middle_channels,  3, 1, 1)
        self.conv_1 = nn.Conv2d(middle_channels, out_channels, 1, 1)

        self.att = spatial_att_new(middle_channels)
        # 是两个批标准化层,分别应用于中间通道数和输出通道数
        self.bn = nn.BatchNorm2d(middle_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU()

        self.conv3_1 = nn.Conv2d(middle_channels, middle_channels, 3, 1, padding=1, dilation=1)
        self.conv3_2 = nn.Conv2d(middle_channels, middle_channels, 3, 1, padding=1, dilation=1)
        self.conv3_3 = nn.Conv2d(middle_channels, middle_channels, 3, 1, padding=1, dilation=1)
        self.conv3_4 = nn.Conv2d(middle_channels, middle_channels, 3, 1, padding=1, dilation=1)
        self.conv1 = nn.Conv2d(middle_channels, middle_channels, 1, 1)

        self.conv_y = nn.Conv2d(1,1, 3, 1, 1)  # 0804
        self.conv_z = nn.Conv2d(64, 64, 3, 1, 1)


    def forward(self, x):
        # 从X中取出第1个通道到第64个通道之间的部分
        noise_input = x[:, 0:64, :, :]  # torch.Size([1, 32, 512, 512])
        # 从x中取出第64通道
        y_input = x[:, 64:65, :, :]     # torch.Size([1, 1, 512, 512])

        noise_input = self.conv_z(noise_input)
        # y = torch.sigmoid(self.conv_y(y_input))

        y = torch.sigmoid(self.conv_y(y_input))
        x = torch.cat([noise_input, y], dim=1)

        # 连接x和 y 作为输入
        x = self.act(self.bn(self.conv_3(x)))

        # 第一个瓶颈残差块 (BRB)
        identity1 = x
        x = self.act(self.bn(self.conv3_1(x)))
        x = self.conv3_2(x)
        x1 = x + self.conv1(identity1)
        x1 = self.act(self.bn(x1))

        # 第二个瓶颈残差块 (BRB)
        identity2 = x1
        x = self.act(self.bn(self.conv3_3(x1)))
        x = self.conv3_4(x)
        x2 = x + self.conv1(identity2)
        x2 = self.act(self.bn(x2))

        # 第三个瓶颈残差块 (BRB)
        identity3 = x2
        x = self.act(self.bn(self.conv3_3(x2)))
        x = self.conv3_2(x)
        x3 = x + self.conv1(identity3)
        x3 = self.act(self.bn(x3))

        # spatial–spectral attention module
        x3 = self.att(x3)

        # 最后输出
        x = self.conv_1(x3)
        x = torch.sigmoid(x)

        return x

 2.4测试

# 创建模型实例
model = hslnet(in_channels=65, out_channels=1, middle_channels=32)

# 定义一个输入张量进行测试
input_tensor = torch.randn(1, 65, 512, 512)  # 假设输入大小为[1, 65, 512, 512]

# 运行模型
output = model(input_tensor)
print("Output shape:", output.shape)

输出结果

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值