基于通道注意力机制的高光谱图像分类

本文探讨了高光谱图像分类任务,基于3D-2D卷积的HybridSN网络,并介绍了如何引入通道注意力机制SELayer进行改进。通过训练和测试,模型的准确率从95.71%提升至97.34%,验证了注意力机制的有效性。
摘要由CSDN通过智能技术生成


前言

近年来,由于高光谱数据的独特性质以及所包含的海量信息,对于高光谱图像的分析与处理已经成为遥感影像研究领域的热点之一,而其中的高光谱图像分类任务又对地质勘探、农作物检测、国防军事等领域起着实质性的重要作用,值得更加深入的研究。然而高光谱图像分类任务中,数据特征的获取和学习一直是研究的重点与难点,如何提取充分有效的特征直接影响到分类结果的好坏。

文章参考了论文《HybridSN: Exploring 3-D–2-DCNN Feature Hierarchy for Hyperspectral Image Classification》后,这篇论文构建了一个混合网络解决高光谱图像分类问题,首先用 3D卷积,然后使用 2D卷积。

在此基础上,文章尝试加入通道注意力机制,得到了更好的效果。


步骤

  • 定义网络和创建数据集
  • 训练与测试
  • 基于通道注意力机制改进网络
  • 再次训练和测试,并进行对比

Python 实现

1.定义HybridSN类

  • 模型的网络结构为如下图所示:
    在这里插入图片描述

三维卷积部分:

  • conv1:(1, 30, 25, 25), 8个 7x3x3 的卷积核 ==>(8, 24, 23, 23)
  • conv2:(8, 24, 23, 23), 16个 5x3x3 的卷积核 ==>(16, 20, 21, 21)
  • conv3:(16, 20, 21, 21),32个 3x3x3 的卷积核 ==>(32, 18, 19, 19)

接下来要进行二维卷积,因此把前面的 32*18 reshape 一下,得到 (576, 19, 19)

二维卷积:(576, 19, 19) 64个 3x3 的卷积核,得到 (64, 17, 17)

接下来是一个 flatten 操作,变为 18496 维的向量,

接下来依次为256,128节点的全连接层,都使用比例为0.4的 Dropout,

最后输出为 16 个节点,是最终的分类类别数。

  • 下面是HybridSN网络层数结构表:
    在这里插入图片描述

  • 下面是HybridSN网络定义:

class HybridSN(nn.Module):

    def __init__(self):

        super(HybridSN, self).__init__()

        self.conv_3d = nn.Sequential(
            nn.Conv3d(1, 8, (7, 3, 3)),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(8, 16, (5, 3, 3)),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(16, 32, (3, 3, 3)),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.conv_2d = nn.Sequential(
            nn.Conv2d(576, 64, (3, 3)),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.linear = nn.Sequential(
            nn.Linear(18496, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.4),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.4),
            nn.Linear(128, CLASS_NUM),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, x):

        x = self.conv_3d(x)
        x = x.view(-1, x.shape[1] * x.shape[2], x.shape[3], x.shape[4])
        x = self.conv_2d(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)

        return x
  • 测试一下网络结构
def test_net():

    # 随机输入,测试网络结构是否通
    x = torch.randn(1, 1, 30, 25, 25)
    net = HybridSN()
    y = net(x)
    print(y.shape)

得到结果如下,可见网络结构是正常的
在这里插入图片描述

2.创建数据集

首先对高光谱数据实施PCA降维;然后创建 keras 方便处理的数据格式;然后随机抽取 10% 数据做为训练集,剩余的做为测试集。

  • 首先定义基本函数:
# 对高光谱数据 x 应用 PCA 变换
def apply_pca(x, num_components):

    new_x = np.reshape(x, (-1, x.shape[2]))
    pca = PCA(n_components=num_components, whiten=True)
    new_x = pca.fit_transform(new_x)
    new_x = np.reshape(new_x, (x.shape[0], x.shape[1], num_components))

    return new_x


# 对单个像素周围提取 patch 时,边缘像素就无法取了,因此,给这部分像素进行 padding 操作
def padding_with_zeros(x, margin=2):

    new_x = np.zeros((x.shape[0] + 2 * margin, x.shape[1] + 2 * margin, x.shape[2]))
    x_offset = margin
    y_offset = margin
    new_x[x_offset:x.shape[0] + x_offset, y_offset:x.shape[1] + y_offset, :] = x

    return new_x


# 在每个像素周围提取 patch ,然后创建成符合 keras 处理的格式
def create_image_cubes(x, y, window_size=5, remove_zero_labels=True):

    # 给 x 做 padding
    margin = int((window_size - 1) / 2)
    zero_padded_x = padding_with_zeros(x, margin=margin)
    # split patches
    patches_data = np.zeros((x.shape[0] * x.shape[1], window_size, window_size, x.shape[2]))
    patches_labels = np.zeros((x.shape[0] * x.shape[1]))
    patch_index = 0
    for r in range(margin, zero_padded_x.shape[0] - margin):
        for c in range(margin, zero_padded_x.shape[1] - margin):
            patch = zero_padded_x[r - margin:r + margin + 1, c - margin:c + margin + 1]
            patches_data[patch_index, :, :, :] = patch
            patches_labels[patch_index] = y[r-margin, c-margin]
            patch_index += 1
    if remove_zero_labels:
        patches_data = patches_data[patches_labels > 0, :, :, :]
        patches_labels = patches_labels[patches_labels >
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值