原文:How Does DGL Represent A Graph?
DGL笔记1——用DGL表示图
DGL笔记2——用DGL识别节点
DGL笔记3——自己写一个GNN模型
GNN 对于很多图机器学习任务来说是一个很强大的工具。这篇文章我们会学习使用 GNN 进行节点分类的工作流程,也就是“如何为节点分类”。
学习完本文,我们可以
- 加载 DGL 自带的数据集。
- 使用 DGL 提供的神经网络模块构建 GNN 模型。
- 在 CPU 或 GPU 上训练用于节点分类的 GNN 模型,并且评估效果。
当然这都建立在你已经有 PyTorch 的使用经验上。
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
GNN与节点分类概述
应该说基于图数据的最常见任务就是节点分类了。其中,我们训练的模型需要预测每个节点的真实类别。在图神经网络出现之前,很多方法要么单独使用连通性(比如 DeepWalk 或者 node2vec),要么使用连通性和节点自身特征的简单组合。相比之下,GNN 提供了一种通过结合连通性和 局部邻域(local neighborhood) 特征来表示节点的办法。
Kipf 等人提出的一个将节点分类问题表述为半监督节点分类任务就是一个典型例子。仅仅需要一小部分标记的节点,图神经网络(GNN)就可以准确地预测其他节点的类别。
接下来我们将展示如何使用 Cora 数据集上的少量标签构建这样一个用于半监督节点分类的 GNN,这是一个以“论文”为节点、以“引用”为边的引文网络。GNN 的任务是预测给定论文的类别。每个论文节点都包含一个词计数向量(Word count vector)作为其特征,并进行了归一化处理,使它们总和为 1,如论文第 5.2 节所述。
读取
import dgl.data
dataset = dgl.data.CoraGraphDataset()
print('Number of categories:', dataset.num_classes)
Out:
Downloading /Users/david/.dgl/cora_v2.zip from https://data.dgl.ai/dataset/cora_v2.zip...
Extracting file to /Users/david/.dgl/cora_v2
Finished data loading and preprocessing.
NumNodes: 2708
NumEdges: 10556
NumFeats: 1433
NumClasses: 7
NumTrainingSamples: 140
NumValidationSamples: 500
NumTestSamples: 1000
Done saving data into cached files.
Number of categories: 7
首先会下载这个数据集,然后打印出这个数据集的种类。我们可以看到这个数据集的基本信息:
- 2708个节点(nodes)
- 10056条边(edges)
- 1433个特征(feats)
- 7个类型(classes)
一个 DGL 自带数据集可能包括一个或者多个图,今天我们使用的 Cora 数据集只有一个图。
g = dataset[0]
print(g)
Out:
Graph(num_nodes=2708, num_edges=10556,
ndata_schemes={'feat': Scheme(shape=(1433,), dtype=torch.float32),
'label': Scheme(shape=(), dtype=torch.int64),
'val_mask': Scheme(shape=(), dtype=torch.bool),
'test_mask': Scheme(shape=(), dtype=torch.bool),
'train_mask': Scheme(shape=(), dtype=torch.bool)}
edata_schemes={})