图神经网络学习task03(基于图神经网络的节点表征学习)

本文探讨了图神经网络(GNN)在节点表征学习中的作用,通过对比MLP、GCN和GAT在Cora数据集上的节点分类任务表现,揭示了GNN如何利用节点间连接关系提升预测精度。GCN达到了81.4%的精度,表现出色。实验结果显示,考虑邻接节点信息是GNN优于传统深度神经网络的关键。
摘要由CSDN通过智能技术生成

一、本次打开学习任务3:基于图神经网络的节点表征学习

在图节点预测或边预测任务中,首先需要生成节点表征(Node Representation)。我们使用图神经网络来生成节点表征,并通过基于监督学习的对图神经网络的训练,使得图神经网络学会产生高质量的节点表征。高质量的节点表征能够用于衡量节点的相似性,同时高质量的节点表征也是准确分类节点的前提。
本节中,将学习实现多层图神经网络的方法,并以节点分类任务为例,学习训练图神经网络的一般过程。我们将以Cora 数据集为例子进行说明,Cora 是一个论文引用网络,节点代表论文,如果两篇论文存在引用关系,则对应的两个节点之间存在边,各节点的属性都是一个1433维的词包特征向量。
任务是预测各篇论文的类别(共7类)。
此外,还将对MLP和GCN, GAT(两个知名度很高的图神经网络)三类神经网络在节点分类任务中的表现进行比较分析,以此来展现图神经网络的强大和论证图神经网络强于普通深度神经网络的原因。

二、主要步骤:
获取并分析数据集、构建一个方法(MLP和GCN, GAT)用于分析节点表征的分布。
使用该方法构建实例模型进行训练
进行节点分类。

三、实验结果对比:

模型精度
MLP59%
GCN81.4%
GAT73.8%

具体原理及程序代码可以参考datawhale的组队学习训练网页
四、结语
在节点表征的学习中,MLP神经网络只考虑了节点自身属性,忽略了节点之间的连接关系,它的结果是最差的;
而GCN图神经网络与GAT图神经网络,同时考虑了节点自身信息与周围邻接节点的信息,因此它们的结果都优于MLP神经网络。
也就是说,对周围邻接节点的信息的考虑,是图神经网络由于普通深度神经网络的原因。

五、作业
参照这份代码使用PyG中不同的图卷积模块在PyG的不同数据集上实现节点分类或回归任务。

使用SAGEGCN模型,主要代码如下:

from torch_geometric.nn import SAGEConv

class SAGE(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(SAGE, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = SAGEConv(dataset.num_features, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x

model = SAGE(hidden_channels=16)
print(model)

识别精度为79.7%

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值