一、本次打开学习任务3:基于图神经网络的节点表征学习
在图节点预测或边预测任务中,首先需要生成节点表征(Node Representation)。我们使用图神经网络来生成节点表征,并通过基于监督学习的对图神经网络的训练,使得图神经网络学会产生高质量的节点表征。高质量的节点表征能够用于衡量节点的相似性,同时高质量的节点表征也是准确分类节点的前提。
本节中,将学习实现多层图神经网络的方法,并以节点分类任务为例,学习训练图神经网络的一般过程。我们将以Cora 数据集为例子进行说明,Cora 是一个论文引用网络,节点代表论文,如果两篇论文存在引用关系,则对应的两个节点之间存在边,各节点的属性都是一个1433维的词包特征向量。
任务是预测各篇论文的类别(共7类)。
此外,还将对MLP和GCN, GAT(两个知名度很高的图神经网络)三类神经网络在节点分类任务中的表现进行比较分析,以此来展现图神经网络的强大和论证图神经网络强于普通深度神经网络的原因。
二、主要步骤:
获取并分析数据集、构建一个方法(MLP和GCN, GAT)用于分析节点表征的分布。
使用该方法构建实例模型进行训练
进行节点分类。
三、实验结果对比:
模型 | 精度 |
---|---|
MLP | 59% |
GCN | 81.4% |
GAT | 73.8% |
具体原理及程序代码可以参考datawhale的组队学习训练网页:
四、结语
在节点表征的学习中,MLP神经网络只考虑了节点自身属性,忽略了节点之间的连接关系,它的结果是最差的;
而GCN图神经网络与GAT图神经网络,同时考虑了节点自身信息与周围邻接节点的信息,因此它们的结果都优于MLP神经网络。
也就是说,对周围邻接节点的信息的考虑,是图神经网络由于普通深度神经网络的原因。
五、作业
参照这份代码使用PyG中不同的图卷积模块在PyG的不同数据集上实现节点分类或回归任务。
使用SAGEGCN模型,主要代码如下:
from torch_geometric.nn import SAGEConv
class SAGE(torch.nn.Module):
def __init__(self, hidden_channels):
super(SAGE, self).__init__()
torch.manual_seed(12345)
self.conv1 = SAGEConv(dataset.num_features, hidden_channels)
self.conv2 = SAGEConv(hidden_channels, dataset.num_classes)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = x.relu()
x = F.dropout(x, p=0.5, training=self.training)
x = self.conv2(x, edge_index)
return x
model = SAGE(hidden_channels=16)
print(model)
识别精度为79.7%