【机器学习】训练GNN图神经网络模型进行节点分类

1. 引言

1.1 图神经网络GNN概述

图神经网络(Graph Neural Network,GNN)是一种专门用于处理图结构数据的神经网络方法。它起源于2005年,当时Gori等人首次提出了GNN的概念,用于学习图中的节点特征以及它们之间的关系。随后,随着深度学习技术的快速发展,GNN得到了广泛的关注和研究。

1.1.1 GNN的核心算法思想

GNN的核心思想是通过迭代地聚合节点的邻居信息来更新节点的表示,从而捕获图的结构信息。这种聚合过程可以看作是一种特殊的图卷积操作,使得GNN能够有效地处理图数据,并提取出节点和边之间的关系特征。主流的GNN算法包括图卷积神经网络(GCN)、图自编码器(Graph Autoencoder)、图生成网络(Graph Generative Network)等。

1.1.2 GNN的应用场景

GNN的应用场景非常广泛,包括社交网络分析、推荐系统、生物信息学、交通流预测等。例如,在社交网络中,GNN可以用于分析用户之间的关系,预测用户的兴趣和行为;在推荐系统中,GNN可以利用用户-物品交互图来提供个性化的推荐;在生物信息学中,GNN可以用于分析蛋白质-蛋白质相互作用网络,预测蛋白质的功能和性质。

1.1.3 GNN图神经网络面临的挑战

尽管GNN在处理图数据方面取得了显著的进展,但仍面临一些挑战。首先,大规模图数据的处理是一个难题,需要设计高效的GNN架构和训练算法来应对。其次,动态图数据的处理也是一个挑战,因为图的结构和节点属性可能会随时间发生变化。此外,GNN在处理异构图(即节点和边具有不同类型和属性的图)时也需要进一步的研究。

未来,GNN的研究将继续关注提高模型的性能、扩展应用范围以及解决上述挑战。随着深度学习技术的不断进步和计算能力的提升,GNN有望在更多领域发挥重要作用,并推动图数据分析和应用的进一步发展。

1.2 节点分类

节点分类(Node Classification)是图数据分析中的一个重要任务,其目标是根据图的结构信息和节点的属性特征来预测图中节点的类别标签。在图数据中,节点通常表示实体,而边则表示实体之间的关系。节点分类在许多领域都有广泛的应用,如社交网络分析、生物信息学、推荐系统等。

1.2.1 节点分类任务的步骤

在图神经网络(Graph Neural Networks, GNNs)的框架下,节点分类通常通过以下步骤实现:

  • 图数据表示:首先,需要将图数据转化为神经网络可以处理的表示形式。这通常包括节点特征矩阵(用于描述节点的属性信息)和邻接矩阵(用于描述节点之间的关系)。

  • 图神经网络层:然后,使用图神经网络层来聚合节点的邻居信息。这些层通过特定的图卷积操作或图注意力机制来更新节点的表示,从而捕获图的结构信息。不同的GNN架构(如Graph Convolutional Networks, GraphSAGE, Graph Attention Networks等)具有不同的聚合函数和更新规则。

  • 特征传播与聚合:在图神经网络中,特征传播是一个关键步骤。通过迭代地聚合节点的邻居信息,每个节点的表示都会逐渐融合其局部邻域的信息。这种过程可以重复多次,以便捕获更广泛的图结构信息。

  • 节点分类:最后,将聚合后的节点表示输入到分类器(如全连接层、softmax层等)中进行分类。分类器会根据节点的表示预测其所属的类别标签。

节点分类任务中,GNNs的优势在于它们能够捕获图数据的复杂依赖关系,并利用这些信息来提高分类性能。与传统的机器学习方法相比,GNNs能够更好地处理图数据的非欧几里得结构,并考虑节点之间的连接关系。

1.2.2 节点分类面临的挑战

然而,节点分类也面临一些挑战。首先,当图的规模非常大时,GNNs的训练和推理过程可能会变得非常耗时。其次,对于动态图数据,GNNs需要能够处理节点和边的添加、删除以及属性变化等情况。此外,当图中的节点和边具有不同的类型和属性时,GNNs也需要具备处理异构图的能力。

展望将来,节点分类的研究将继续关注提高GNNs的性能、扩展应用范围以及解决上述挑战。随着深度学习技术的不断进步和计算能力的提升,GNNs有望在更多领域发挥重要作用,并推动图数据分析和应用的进一步发展。

1.3 GNN模型用于节点分类

图神经网络(GNN)在节点分类任务中发挥着关键作用。节点分类的目标是根据图的结构信息和节点的属性特征来预测图中每个节点的类别标签。GNN模型通过迭代地更新节点的表示来捕获节点之间的相互作用和依赖关系,进而实现高效的节点分类。

1.3.1 GNN进行节点分类的原理

GNN模型的工作原理在于通过聚合节点的邻居信息来更新节点的表示。首先,每个节点的表示向量被初始化为其初始特征向量。然后,在信息传递阶段,GNN通过图卷积操作或图注意力机制来聚合节点的邻居节点的信息,并更新节点的表示。这一过程中,邻居节点的信息被加权求和或聚合,以反映它们对目标节点的影响。通过多次迭代信息传递和聚合邻居信息的过程,GNN能够逐步更新节点的表示,使其包含更丰富的图结构信息。

1.3.2 GNN的优势及应用

GNN模型在节点分类任务中的优势在于其能够捕获图数据中的复杂依赖关系和结构信息,这对于节点分类任务至关重要。同时,GNN模型具有灵活性和可扩展性,能够处理具有不同大小和结构的图数据,并适应大规模图数据的处理需求。

GNN在节点分类任务中的应用广泛,涵盖了社交网络分析、生物信息学和推荐系统等领域。在社交网络中,GNN可以识别用户所属的潜在角色或群体,并根据兴趣对用户进行分类。在生物信息学中,GNN可以应用于分子结构的分类,预测分子的功能和性质。在推荐系统中,GNN可以处理用户-项目交互的图结构数据,预测用户对未知物品的喜好程度,并向用户推荐合适的项目。

随着深度学习技术的不断进步和计算能力的提升,GNN有望在更多领域发挥重要作用,并推动图数据分析和应用的进一步发展。

2. GNN模型实现节点分类的过程

许多在各种机器学习(ML)应用中的数据集,其实体之间存在结构关系,这些关系可以表示为图。这类应用包括社交和通信网络分析、交通预测以及欺诈检测。图表示学习旨在构建和训练用于图数据集的模型,以便用于各种机器学习任务。

本文的例子展示了一个图神经网络(GNN)模型的简单实现。该模型用于Cora数据集上的节点预测任务,以根据论文的词汇和引用网络来预测论文的主题。

本文我们从头开始实现了一个图卷积层,以便更好地理解它们是如何工作的。然而,也有许多基于TensorFlow的专用库提供了丰富的GNN API,例如Spectral、StellarGraph和GraphNets等。

2.1 设置

import os
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

2.2 准备下载

以下的段代码是用于下载和解压缩Cora数据集的Python脚本:

  1. 下载数据集
    - 使用keras.utils.get_file函数来下载数据集。这个函数是Keras提供的,用于下载文件并保存到本地路径。
    - fname="cora.tgz": 指定下载文件的名称为cora.tgz
    - origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz": 指定数据集的URL来源,即数据集在互联网上的位置。
    - extract=True: 表示下载完成后自动解压缩文件。

  2. 设置数据目录
    - data_dir = os.path.join(os.path.dirname(zip_file), "cora"): 这行代码设置了解压缩后数据的存放目录。
    - os.path.dirname(zip_file): 获取zip_file变量的目录路径,即下载文件的存放路径。
    - os.path.join(..., "cora"): 将下载文件的目录路径与"cora"字符串连接,形成完整的数据集目录路径。

代码的作用是确保Cora数据集被下载并解压缩到程序可以访问的目录中。Cora数据集通常用于图神经网络的节点分类任务,包含了论文的引用关系和内容信息。在机器学习或深度学习项目中,这样的数据准备步骤是常见的,以便于后续的数据加载和处理。

zip_file = keras.utils.get_file(
    fname="cora.tgz",
    origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",
    extract=True,
)
data_dir = os.path.join(os.path.dirname(zip_file), "cora")

2.3 数据预处理

2.3.1 加载并可视化数据集

处理并可视化数据集代码是用来加载和处理Cora数据集中的引用信息:

  1. 加载数据
    - 使用pd.read_csv函数从Pandas库来读取CSV文件。这个函数是用来读取CSV文件并将其转换成DataFrame对象的。

  2. 指定文件路径
    - os.path.join(data_dir, "cora.cites"): 这行代码通过os.path.join函数来拼接数据集的基础目录data_dir和子目录"cora.cites",形成完整的文件路径。data_dir是在上一段代码中设置的Cora数据集的目录路径。

  3. 设置分隔符和列名
    - sep="\t": 指定分隔符为制表符(\t),这意味着CSV文件中的列是以制表符分隔的。
    - header=None: 指定文件中没有头部信息,即列名不在文件的第一行中。
    - names=["target", "source"]: 指定列名,即使文件没有头部信息,这里明确了两列的名称分别为"target""source"

  4. 打印引用数据的形状
    - print("Citations shape:", citations.shape): 打印出DataFrame citations 的形状,即行数和列数。这通常用于检查数据加载是否正确,以及理解数据集的规模。

引用数据集citations通常包含了两列,一列是被引用论文的ID(target),另一列是引用它的论文的ID(source)。在图神经网络中,这些引用可以表示为图中的边,其中每条边连接了两个节点(论文),用于捕捉论文之间的相互关系。

citations = pd.read_csv(
    os.path.join(data_dir, "cora.cites"),
    sep="\t",
    header=None,
    names=["target", "source"],
)
print("Citations shape:", citations.shape)
2.3.2 展示数据框

下面的代码将展示引文(citations)数据框的一个样本。目标列包含了由源列中的论文ID所引用的论文ID。

在Pandas中,citations.sample(frac=1).head()这行代码执行了以下操作:

  1. citations.sample(frac=1): 这是一个方法调用,作用是从citations DataFrame中随机抽取记录。参数frac=1表示抽取全部的记录(即100%的记录)。如果不指定frac或者设置frac<1,则会按照给定的比例随机抽取记录。

  2. .head(): 这个方法调用返回采样结果的前n行,其中n默认为5。也就是说,它会返回一个DataFrame的前5行。这是查看DataFrame内容的一个快速方法,特别是当你不想查看整个数据集时。

综合来看,citations.sample(frac=1).head()这行代码的作用是:从citations DataFrame中随机抽取全部记录,然后返回这些记录的前5行。这通常用于获取数据集的一个随机样本,以便进行快速检查或展示数据的多样性。

请注意,每次执行这个操作时,由于是随机采样,返回的5行可能会不同。如果你希望每次采样都得到相同的结果,可以在调用sample方法时指定随机种子,例如citations.sample(frac=1, random_state=1).head()

citations.sample(frac=1).head()
2.3.3 加载数据到Pandas DataFrame

现在,让我们将论文数据加载到Pandas DataFrame中。

首先创建了一个名为column_names的列表,它包含了Cora数据集中论文内容文件的所有列名。第一列是paper_id,接下来的1433列用term_0term_1432命名,代表论文中术语的存在与否,最后一列是subject,表示论文的主题。

然后使用pd.read_csv函数读取Cora数据集的cora.content文件,这个文件包含了论文的内容信息。通过os.path.join函数构建文件的完整路径,sep参数指定了制表符作为字段的分隔符,header参数设置为None表示文件中没有提供列名的头部行,names参数用来指定前面创建的列名列表。

最后,使用print函数打印出加载后的papers DataFrame的形状,即它包含的行数和列数,这有助于了解数据集的规模。

column_names = ["paper_id"] + [f"term_{
     idx}" for idx in range(1433)] + ["subject"]
papers = pd.read_csv(
    os.path.join(data_dir, "cora.content"), sep="\t", header=None, names=column_names,
)
print("Papers shape:", papers.shape)
Papers shape: (2708, 1435)
2.3.4 展示Pandas DataFrame

展示论文DataFrame的一个样本。该DataFrame包括paper_id和subject列,以及1,433个二进制列,这些列表示相应术语是否存在于论文中。

print(papers.sample(5).T)

显示每个主题的论文数量

print(papers.subject.value_counts())
Neural_Networks           818
Probabilistic_Methods     426
Genetic_Algorithms        418
Theory                    351
Case_Based                298
Reinforcement_Learning    217
Rule_Learning             180
Name: subject, dtype: int64
2.3.5 创建数据索引和转换

索引和转换的目的是对Cora数据集中的论文和引用信息进行预处理,使其适合用于图神经网络模型的训练。

  • 主题索引的创建:首先,代码提取subject列中所有独特的主题,并按字典顺序对它们进行排序,生成一个有序列表class_values

  • 主题到索引的映射:接着,通过枚举排序后的主题列表,创建一个字典class_idx,该字典将每个主题映射到一个唯一的索引,这个索引将用于后续模型训练中的主题表示。

  • 论文索引的创建:类似地,代码对paper_id列中的所有独特论文ID进行排序,并创建一个字典paper_idx,将排序后的论文ID映射到它们各自的索引。

  • 论文ID的转换:在papers数据集中,使用paper_idx字典将paper_id列中的论文ID转换为对应的数值索引。

  • 引用数据的转换:在citations数据集中,使用paper_idx字典将source(引用源)和target(引用目标)列中的论文ID转换为数值索引。

  • 主题的转换:最后,使用class_idx字典将papers数据集中的subject列中的主题名称转换为对应的数值索引。

通过这些步骤,原始数据集中的文本标识符被转换为模型易于处理的数值形式,为后续的图神经网络模型训练做好了准备。这种转换有助于模型更高效地学习论文之间的引用关系以及它们对应的主题分类。

class_values = sorted(papers["subject"].unique())
class_idx = {
   name: id for id, name in enumerate(class_values)}
paper_idx = {
   name: idx for idx, name in enumerate(sorted(papers["paper_id"].unique()))}

papers["paper_id"] = papers["paper_id"].apply(lambda name: paper_idx[name])
citations["source"] = citations["source"].apply(lambda name: paper_idx[name])
citations["target"] = citations["target"].apply(lambda name: paper_idx[name])
papers["subject"] = papers["subject"].apply(lambda value: class_idx[value])
2.3.6 可视化引用图

代码用于可视化Cora数据集的引用网络,具体步骤如下:

  • 设置图形大小:使用plt.figure函数创建一个新的图形对象,并设置图形的大小为10x10英寸,以确保有足够的空间来展示网络图。

  • 准备颜色列表:从papers DataFrame中提取subject列的所有唯一值,将其转换为列表colors。这个列表将用于为图中的节点着色,每个主题的论文将使用不同的颜色。

  • 生成图数据结构:利用networkx库的from_pandas_edgelist函数,从citations DataFrame中随机抽取1500条引用记录来构建图cora_graph。这个图的节点代表论文,边代表论文之间的引用关系。

  • 筛选节点主题:从papers DataFrame中筛选出与cora_graph图中节点对应的论文主题列表subjects。这确保了图中的每个节点都能根据其主题着上正确的颜色。

  • 绘制网络图:使用networkxdraw_spring函数绘制图cora_graph,其中node_size参数设置为15,控制节点的大小。node_color参数设置为subjects列表,这样每个节点就会根据其主题被着上不同的颜色。

通过这种方式,代码生成了一个视觉化的网络图,其中节点代表论文,节点之间的连线代表引用关系,节点的颜色表示论文的主题。这种可视化有助于直观地理解数据集中论文之间的相互关系及其主题分布。

plt.figure(figsize=(10, 10))
colors = papers["subject"].tolist()
cora_graph = nx.from_pandas_edgelist(citations.sample(n=1500))
subjects = list(papers[papers["paper_id"].isin(list(cora_graph.nodes))]["subject"])
nx.draw_spring(cora_graph, node_size=15, node_color=subjects)

图中的每个节点代表一篇论文,节点的颜色对应其主题。请注意,我们仅显示数据集中论文的一个样本。
在这里插入图片描述

2.3.7 数据分割

将数据集划分为分层的训练集和测试集
代码的目的是将Cora数据集分割为训练集和测试集,具体步骤如下:

  • 初始化空列表:首先,初始化两个空列表train_datatest_data,用于存储训练集和测试集的数据。

  • 分组并随机选择:使用papers.groupby("subject")对数据集中的论文按主题进行分组。对于每个主题组,使用np.random.rand生成一个与该组论文数量相同的随机数数组。通过比较这些随机数与0.5,决定每篇论文是否被选入训练集(小于或等于0.5的被选入训练集,其余的被选入测试集)。

  • 分配训练集和测试集:对于每个主题组,根据随机选择的结果,将论文分配到train_datatest_data列表中。

  • 合并数据:使用pd.concat函数将train_datatest_data列表中的所有分组数据合并成两个单独的DataFrame,分别代表整个训练集和测试集。

  • 随机打乱数据:使用.sample(frac=1)确保训练集和测试集中的论文是随机分布的,frac=1表示打乱全部数据。

  • 打印数据形状:最后,打印出训练集和测试集的形状,即它们各自的行数和列数,以验证数据集的规模。

通过这种方式,代码

评论 16
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

MUKAMO

你的鼓励是我们创作最大的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值