基于图卷积神经网络(GCN)的高光谱图像分类详细教程(含python代码)

目录

一、背景

二、基于卷积神经网络的代码实现

1、安装依赖库

2、建立图卷积神经网络

3、建立数据的边

4、训练模型

5、可视化

三、项目代码


一、背景

图卷积神经网络(Graph Convolutional Networks, GCNs)在高光谱图像分类中是一种有效的方法,特别适用于处理具有复杂空间关系的数据。高光谱图像通常包含数百个甚至数千个连续的频谱波段,每个波段对应一个光谱特征,这使得传统的卷积神经网络在处理高光谱图像时面临困难,因为它们无法有效地捕获像素之间的空间关系。

GCNs通过利用图结构来解决这一问题,将像素(或者像素附近的区域)视为图中的节点,并利用这些节点之间的关系进行特征学习和分类。以下是GCNs在高光谱图像分类中的一些关键点和优势:

  1. 图结构建模:将高光谱图像中的像素视为图中的节点,像素之间的空间关系(例如邻近关系)作为图的边,这样就能够在整个图上利用节点的局部和全局信息。

  2. 卷积操作:GCN引入了图卷积操作,允许在图结构上进行类似于传统卷积神经网络中的卷积操作。这种操作可以捕获节点及其邻居的特征,并利用这些信息来提取更有意义的特征表示。

  3. 特征学习:通过多层的图卷积操作,GCNs能够逐步学习出更加抽象和高级的特征表示,这对于高光谱数据的复杂特征提取尤为重要。

  4. 分类器:最后一层通常是一个分类器,用于将学习到的特征映射到类别标签空间,从而进行分类。

  5. 适应性:GCNs在处理高光谱图像时具有很强的适应性和灵活性,能够处理不同大小和分辨率的图像,以及不同数量和配置的频谱波段。

总体来说,图卷积神经网络通过充分利用高光谱图像中像素之间的空间关系,有效地提升了分类性能,并在遥感图像分析和其他高维数据的处理中展现出了广阔的应用前景。

二、基于卷积神经网络的代码实现

下面我们以IP数据集为例子进行展开讲解。

1、安装依赖库
matplotlib==3.3.4
networkx==2.1
numpy==1.19.5
pandas==1.1.5
scikit_learn==1.5.1
scipy==1.5.4
seaborn==0.11.2
spectral==0.22.4
torch==1.7.1+cu110
torch_geometric==2.0.2
tqdm==4.62.3
2、建立图卷积神经网络
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import torch
import torch.nn as nn

class GCN(torch.nn.Module):
    def __init__(self, num_node_features, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_node_features, 32)
        self.conv1_bn_relu = nn.Sequential(
            nn.BatchNorm1d(32),
            nn.ReLU()
        )

        self.conv2 = GCNConv(32, 64)
        self.conv2_bn_relu = nn.Sequential(
            nn.BatchNorm1d(64),
            nn.ReLU()
        )

        self.cls = nn.Sequential(
            nn.Linear(64, num_classes),
        )

    def forward(self, edge, data):
        x, edge_index = data, edge
        x = self.conv1_bn_relu(self.conv1(x, edge_index))
        x = self.conv2_bn_relu(self.conv2(x, edge_index))

        return self.cls(x)
3、建立数据的边

首先进行PCA数据降维:

X_pca = applyPCA(X, numComponents=pca_components)

然后将无标签数据进行剔除:

    X_pca = X_pca.reshape(-1,pca_components)
    y = y.ravel()
    mask = y == 0

    # 剔除无标签的数据
    data = X_pca[~mask]
    label = y[~mask]

划分训练验证集(训练70%):

X_train, X_test, y_train, y_test = splitTrainTestSet(range(len(data)),label,trainRatio=0.7)

最后建立所有样本的边(这里取最近邻的样本为3):

Edge_build(data,k=3)

4、训练模型

加载数据和模型:

    X_train_index,X_test_index = utils.create_train_test('./data/'+patch_+'/train_index.txt',
                                                          './data/'+patch_+'/test_index.txt')
    data,label = utils.create_features('./data/'+patch_+'/data.txt',
                                        './data/'+patch_+'/label.txt')
    edge = pd.read_csv('./data/'+patch_+'/edge.txt', sep=" ", header=None).values.T

    # 建立模型
    model = GCN(30, 16)

训练模型:

class Trainer():
    def __init__(self, data,y,edge,X_train_index,X_test_index, model, optimizer, loss_function, epochs):
        self.y = y
        self.edge = torch.from_numpy(edge).type(torch.LongTensor).to(device)
        self.X_train_index = X_train_index
        self.X_test_index = X_test_index
        self.data = torch.from_numpy(data).type(torch.FloatTensor).to(device)
        self.model = model.to(device)
        self.optimizer = optimizer
        self.loss_function = loss_function
        self.epochs = epochs

        self.y_train = torch.from_numpy(y[X_train_index]).type(torch.LongTensor).to(device)
        self.y_test = torch.from_numpy(y[X_test_index]).type(torch.LongTensor).to(device)

        self.preds = None
    def train(self):
        pass

    def test(self):
        self.model.eval()
        pass

    trainer = Trainer(
        data=data,
        y=label,
        edge=edge,
        X_train_index=X_train_index,
        X_test_index=X_test_index,
        model=model,
        optimizer=optim.Adam(model.parameters(), lr=0.001),
        loss_function=nn.CrossEntropyLoss(),
        epochs=1000
    )

    trainer.train()
    trainer.test()

5、可视化
if __name__ == '__main__':
    patch_ = "IP"

    graph, A = utils.create_Graphs_with_attributes_adjadjency_matrix('./data/' + patch_ + '/edge.txt',
                                                                     './data/' + patch_ + '/data.txt')
    data, label = utils.create_features('./data/' + patch_ + '/data.txt',
                                        './data/' + patch_ + '/label.txt')
    edge = pd.read_csv('./data/' + patch_ + '/edge.txt', sep=" ", header=None).values.T

    model = GCN(30, 16)
    model.eval()
    net_params = torch.load("./weight/model.pkl")
    model.load_state_dict(net_params)  # 加载模型可学习参数

    trainer = Trainer(
        data=data,
        y=label,
        edge=edge,
        model=model,
    )

    pred = trainer.pre() + 1

    y_ = sio.loadmat('./data/Indian_pines_gt.mat')['indian_pines_gt']
    a, b = y_.shape
    print('Label shape: ', y_.shape)

    y = y_.ravel()
    mask = y == 0

    outputs = np.zeros_like(y)
    outputs[~mask] = pred

    outputs = outputs.reshape((a, b))

    import spectral
    import matplotlib.pyplot as plt

    predict_image = spectral.imshow(classes=outputs.astype(int), figsize=(5, 5))
    plt.savefig('./results/pre.png', dpi=300)
    plt.pause(1)

三、项目代码

本项目的代码通过以下链接下载:基于图卷积神经网络(GCN)的高光谱图像分类详细教程(含python代码)

  • 5
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
高光谱图像分类是遥感像处理的一个重要分支,在许多领域都有广泛的应用,如农业、林业、环境监测等。GCN(Graph Convolutional Network,卷积网络)是近年来被广泛应用于分类任务的一种神经网络模型,能够有效地提取像中的特征信息,因此也被用于高光谱图像分类中。 GCN高光谱图像分类代码需要进行以下步骤: 1.准备数据集:从公共数据集中下载高光谱像数据集,如Indian Pines数据集,包224x224个像素的224个波段。 2.对原始数据进行预处理:对数据进行标准化处理,将像素值转换为(0,1)之间的范围。 3.构建GCN模型: 使用Python中的keras或tensorflow等深度学习框架,构建GCN模型,包括结构、卷积层、池化层、激活函数等。 4.训练模型:使用训练集对模型进行训练,并使用交叉验证进行调参,找到最佳的超参数。 5.预测分类:使用测试数据集对模型进行预测,并计算预测结果的准确性和精确度。 GCN高光谱图像分类代码需要注意的点包括: 1.在构建GCN模型时需要使用结构,并考虑到的不规则性和稀疏性,适应高光谱像数据集的特点。 2.在训练模型时需要考虑到过拟合的问题,可以使用dropout等技术来避免。 3.预处理的方法要合适,不同的预处理方法可能会对模型的预测结果产生不同影响。 4.需要选择适当的评估指标,如准确性和精确度等。 总之,GCN高光谱图像分类代码需要深入理解卷积网络的原理和高光谱像的特点,充分发挥GCN分类任务中的优势,并在数据预处理、模型构建、训练和预测等方面进行综合考虑才能达到更好的分类结果。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

清纯世纪

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值