【深度学习】PyTorch框架(6):GNN图神经网络理论和实践
1.引言
在本文中,我们将探讨图神经网络(GNNs)在图上的应用。近年来,图神经网络在社交网络、知识图谱、推荐系统和生物信息学等多个领域中越来越受到关注。尽管GNNs背后的理论和数学可能初看之下颇为复杂,但其模型的实现却相对简单,有助于我们深入理解其方法论。因此,本文将重点介绍GNN的基本网络层的实现,包括图卷积和注意力层。最终,我们将演示如何在节点级别、边级别和图级别任务中应用GNN。
首先,我们从导入常用的库开始。我们将使用PyTorch Lightning,这在之前的文章中已经有所涉及。
## 导入标准库
import os
import json
import math
import numpy as np
import time
# 导入绘图相关库
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # 用于导出
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()
sns.set()
# 导入进度条库
from tqdm.notebook import tqdm
# 导入PyTorch库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
# 导入Torchvision库
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms
# 导入PyTorch Lightning库
try:
import pytorch_lightning as pl
except ModuleNotFoundError: # 如果Google Colab未预装PyTorch Lightning,则在此安装
!pip install --quiet pytorch-lightning>=1.4
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
# 设置数据集下载路径(例如CIFAR10)
DATASET_PATH = "../data"
# 设置预训练模型保存路径
CHECKPOINT_PATH = "../saved_models/tutorial7"
# 设置随机种子
pl.seed_everything(42)
# 确保在GPU上的所有操作都是确定性的(如果使用)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# 判断并设置设备
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)
接下来,我们将下载一些预训练的模型。
## 下载预训练模型
import urllib.request
from urllib.error import HTTPError
# 预训练模型存储的GitHub URL
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/"
# 需要下载的文件列表
pretrained_files = ["NodeLevelMLP.ckpt", "NodeLevelGNN.ckpt", "GraphLevelGraphConv.ckpt"]
# 如果检查点路径不存在,则创建它
os.makedirs(CHECKPOINT_PATH, exist_ok=True)
# 对于每个文件,检查它是否已经存在。如果没有,尝试下载它。
for file_name in pretrained_files:
file_path = os.path.join(CHECKPOINT_PATH, file_name)
if "/" in file_name:
os.makedirs(file_path.rsplit("/",1)[0], exist_ok=True)
if not os.path.isfile(file_path):
file_url = base_url + file_name
print(f"正在下载 {
file_url}...")
try:
urllib.request.urlretrieve(file_url, file_path)
except HTTPError as e:
print("下载过程中出现问题。请尝试从其他渠道获取文件,或联系作者并提供完整的错误信息:\n", e)
2.图神经网络
2.1.图的表示方法
在深入探讨图上的特定神经网络操作之前,我们首先需要了解如何表示一个图。在数学上,图 G \mathcal{G} G 可以定义为一个包含节点/顶点集合 V V V 和边/链接集合 E E E 的元组: G = ( V , E ) \mathcal{G}=(V,E) G=(V,E)。每条边由两个顶点组成,代表它们之间的连接。例如,考虑以下图:
图中的顶点为 V = { 1 , 2 , 3 , 4 } V=\{1,2,3,4\} V={ 1,2,3,4},边为 E = { ( 1 , 2 ) , ( 2 , 3 ) , ( 2 , 4 ) , ( 3 , 4 ) } E=\{(1,2), (2,3), (2,4), (3,4)\} E={(1,2),(2,3),(2,4),(3,4)}。为了简化,我们假设图是无向的,因此不包括反向边如 ( 2 , 1 ) (2,1) (2,1)。在实际应用中,顶点和边通常具有特定的属性,边甚至可以是有向的。关键在于我们如何高效地以矩阵操作的方式表示这种多样性。通常,对于边的表示,我们可以选择邻接矩阵或顶点对索引列表。
邻接矩阵 A A A 是一个方阵,其元素表示顶点对是否相邻,即是否相连。在最简单情况下,如果节点 i i i 到 j j j 有连接,则 A i j A_{ij} Aij 为 1,否则为 0。如果图中有边的属性或不同类别的边,这些信息也可以添加到矩阵中。对于无向图, A A A 是一个对称矩阵( A i j = A j i A_{ij}=A_{ji} Aij=Aji)。以上述图为例,其邻接矩阵如下:
A = [ 0 1 0 0 1 0 1 1 0 1 0 1 0 1 1 0 ] A = \begin{bmatrix} 0 & 1 & 0 & 0\\ 1 & 0 & 1 & 1\\ 0 & 1 & 0 & 1\\ 0 & 1 & 1 & 0 \end{bmatrix} A=
0100101101010110
虽然以边的列表形式表达图在内存和计算上更为高效,但使用邻接矩阵更直观且易于实现。在下面的实现中,我们将使用邻接矩阵以简化代码。然而,常用的库可能会使用边列表,我们将在后续讨论中更详细地探讨这一点。
另外,我们也可以通过边的列表定义一个稀疏邻接矩阵,这样我们可以像处理密集矩阵一样进行操作,但更节省内存。PyTorch 通过其 torch.sparse
子包支持这一点(文档),但请注意,这仍处于测试阶段(API 可能在未来发生变化)。
2.2.图卷积
图卷积网络(GCNs)由 [Kipf 等人]在 2016 年在阿姆斯特丹大学提出。他还撰写了一篇关于该主题的优秀 [博客文章],如果你希望从不同角度了解 GCNs,这篇文章是推荐的阅读。GCNs 与图像中的卷积类似,在于“滤波器”参数通常在图中的所有位置共享。同时,GCNs 依赖于消息传递方法,即顶点与邻居交换信息,并向彼此发送“消息”。在深入数学表达之前,我们先尝试直观理解 GCNs 的工作原理。首先,每个节点创建一个特征向量,代表它想要发送给所有邻居的消息。其次,这些消息被发送到邻居,使得每个节点从每个相邻节点接收到一条消息。以下是我们示例图的两个步骤的可视化:
如果我们想用更数学的语言来描述,首先需要决定如何合并一个节点接收到的所有消息。由于不同节点接收到的消息数量不同,我们需要一种适用于任何数量的操作。通常,我们选择求和或取平均。给定节点的先前特征 H ( l ) H^{(l)} H(l),GCN 层定义如下:
H ( l + 1 ) = σ ( D ^ − 1 / 2 A ^ D ^ − 1 / 2 H ( l ) W ( l ) ) H^{(l+1)} = \sigma\left(\hat{D}^{-1/2}\hat{A}\hat{D}^{-1/2}H^{(l)}W^{(l)}\right) H(l+1)=σ(D^−1/2A^D^−1/2H(l)W(l))
W ( l ) W^{(l)} W(l) 是我们将输入特征转换为消息的权重参数。我们在邻接矩阵 A A A 中加入单位矩阵,以便每个节点也将自己的消息发送给自己: A ^ = A + I \hat{A}=A+I A^=A+I。最后,为了计算平均值而不是求和,我们计算对角矩阵 D ^ \hat{D} D^,其中 D i i D_{ii} Dii 表示节点 i i i 拥有的邻居数量。 σ \sigma σ 代表任意激活函数,不一定是 sigmoid(通常在 GNNs 中使用基于 ReLU 的激活函数)。
在 PyTorch 中实现 GCN 层时,我们可以利用张量上的灵活操作。我们不需要定义矩阵 D ^ \hat{D} D^,我们可以在之后简单地将总和消息除以邻居的数量。此外,我们用线性层替换权重矩阵,这也允许我们添加偏置。以下是 PyTorch 模块中 GCN 层的定义:
class GCNLayer(nn.Module):
def __init__(self, c_in, c_out):
super().__init__()
self.projection = nn.Linear(c_in, c_out) # 定义线性层
def forward(self, node_feats, adj_matrix):
"""
前向传播函数
:param node_feats: 节点特征张量,形状为 [batch_size, num_nodes, c_in]
:param adj_matrix: 邻接矩阵的批次,形状为 [batch_size, num_nodes, num_nodes]
:return: 节点的新特征
&#