【深度学习】PyTorch框架(6):GNN图神经网络理论和实践

【深度学习】PyTorch框架(6):GNN图神经网络理论和实践

1.引言

在本文中,我们将探讨图神经网络(GNNs)在图上的应用。近年来,图神经网络在社交网络、知识图谱、推荐系统和生物信息学等多个领域中越来越受到关注。尽管GNNs背后的理论和数学可能初看之下颇为复杂,但其模型的实现却相对简单,有助于我们深入理解其方法论。因此,本文将重点介绍GNN的基本网络层的实现,包括图卷积和注意力层。最终,我们将演示如何在节点级别、边级别和图级别任务中应用GNN。
首先,我们从导入常用的库开始。我们将使用PyTorch Lightning,这在之前的文章中已经有所涉及。

## 导入标准库
import os
import json
import math
import numpy as np 
import time

# 导入绘图相关库
import matplotlib.pyplot as plt
%matplotlib inline 
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf')  # 用于导出
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()
sns.set()

# 导入进度条库
from tqdm.notebook import tqdm

# 导入PyTorch库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

# 导入Torchvision库
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms

# 导入PyTorch Lightning库
try:
    import pytorch_lightning as pl
except ModuleNotFoundError:  # 如果Google Colab未预装PyTorch Lightning,则在此安装
    !pip install --quiet pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# 设置数据集下载路径(例如CIFAR10)
DATASET_PATH = "../data"
# 设置预训练模型保存路径
CHECKPOINT_PATH = "../saved_models/tutorial7"

# 设置随机种子
pl.seed_everything(42)

# 确保在GPU上的所有操作都是确定性的(如果使用)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# 判断并设置设备
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

接下来,我们将下载一些预训练的模型。

## 下载预训练模型
import urllib.request
from urllib.error import HTTPError

# 预训练模型存储的GitHub URL
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/"

# 需要下载的文件列表
pretrained_files = ["NodeLevelMLP.ckpt", "NodeLevelGNN.ckpt", "GraphLevelGraphConv.ckpt"]

# 如果检查点路径不存在,则创建它
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# 对于每个文件,检查它是否已经存在。如果没有,尝试下载它。
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name)
    if "/" in file_name:
        os.makedirs(file_path.rsplit("/",1)[0], exist_ok=True)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print(f"正在下载 {
     file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print("下载过程中出现问题。请尝试从其他渠道获取文件,或联系作者并提供完整的错误信息:\n", e)

2.图神经网络

2.1.图的表示方法

在深入探讨图上的特定神经网络操作之前,我们首先需要了解如何表示一个图。在数学上,图 G \mathcal{G} G 可以定义为一个包含节点/顶点集合 V V V 和边/链接集合 E E E 的元组: G = ( V , E ) \mathcal{G}=(V,E) G=(V,E)。每条边由两个顶点组成,代表它们之间的连接。例如,考虑以下图:

在这里插入图片描述

图中的顶点为 V = { 1 , 2 , 3 , 4 } V=\{1,2,3,4\} V={ 1,2,3,4},边为 E = { ( 1 , 2 ) , ( 2 , 3 ) , ( 2 , 4 ) , ( 3 , 4 ) } E=\{(1,2), (2,3), (2,4), (3,4)\} E={(1,2),(2,3),(2,4),(3,4)}。为了简化,我们假设图是无向的,因此不包括反向边如 ( 2 , 1 ) (2,1) (2,1)。在实际应用中,顶点和边通常具有特定的属性,边甚至可以是有向的。关键在于我们如何高效地以矩阵操作的方式表示这种多样性。通常,对于边的表示,我们可以选择邻接矩阵或顶点对索引列表。

邻接矩阵 A A A 是一个方阵,其元素表示顶点对是否相邻,即是否相连。在最简单情况下,如果节点 i i i j j j 有连接,则 A i j A_{ij} Aij 为 1,否则为 0。如果图中有边的属性或不同类别的边,这些信息也可以添加到矩阵中。对于无向图, A A A 是一个对称矩阵( A i j = A j i A_{ij}=A_{ji} Aij=Aji)。以上述图为例,其邻接矩阵如下:

A = [ 0 1 0 0 1 0 1 1 0 1 0 1 0 1 1 0 ] A = \begin{bmatrix} 0 & 1 & 0 & 0\\ 1 & 0 & 1 & 1\\ 0 & 1 & 0 & 1\\ 0 & 1 & 1 & 0 \end{bmatrix} A= 0100101101010110
虽然以边的列表形式表达图在内存和计算上更为高效,但使用邻接矩阵更直观且易于实现。在下面的实现中,我们将使用邻接矩阵以简化代码。然而,常用的库可能会使用边列表,我们将在后续讨论中更详细地探讨这一点。

另外,我们也可以通过边的列表定义一个稀疏邻接矩阵,这样我们可以像处理密集矩阵一样进行操作,但更节省内存。PyTorch 通过其 torch.sparse 子包支持这一点(文档),但请注意,这仍处于测试阶段(API 可能在未来发生变化)。

2.2.图卷积

图卷积网络(GCNs)由 [Kipf 等人]在 2016 年在阿姆斯特丹大学提出。他还撰写了一篇关于该主题的优秀 [博客文章],如果你希望从不同角度了解 GCNs,这篇文章是推荐的阅读。GCNs 与图像中的卷积类似,在于“滤波器”参数通常在图中的所有位置共享。同时,GCNs 依赖于消息传递方法,即顶点与邻居交换信息,并向彼此发送“消息”。在深入数学表达之前,我们先尝试直观理解 GCNs 的工作原理。首先,每个节点创建一个特征向量,代表它想要发送给所有邻居的消息。其次,这些消息被发送到邻居,使得每个节点从每个相邻节点接收到一条消息。以下是我们示例图的两个步骤的可视化:
添加图片注释,不超过 140 字(可选)
如果我们想用更数学的语言来描述,首先需要决定如何合并一个节点接收到的所有消息。由于不同节点接收到的消息数量不同,我们需要一种适用于任何数量的操作。通常,我们选择求和或取平均。给定节点的先前特征 H ( l ) H^{(l)} H(l),GCN 层定义如下:

H ( l + 1 ) = σ ( D ^ − 1 / 2 A ^ D ^ − 1 / 2 H ( l ) W ( l ) ) H^{(l+1)} = \sigma\left(\hat{D}^{-1/2}\hat{A}\hat{D}^{-1/2}H^{(l)}W^{(l)}\right) H(l+1)=σ(D^1/2A^D^1/2H(l)W(l))
W ( l ) W^{(l)} W(l) 是我们将输入特征转换为消息的权重参数。我们在邻接矩阵 A A A 中加入单位矩阵,以便每个节点也将自己的消息发送给自己: A ^ = A + I \hat{A}=A+I A^=A+I。最后,为了计算平均值而不是求和,我们计算对角矩阵 D ^ \hat{D} D^,其中 D i i D_{ii} Dii 表示节点 i i i 拥有的邻居数量。 σ \sigma σ 代表任意激活函数,不一定是 sigmoid(通常在 GNNs 中使用基于 ReLU 的激活函数)。

在 PyTorch 中实现 GCN 层时,我们可以利用张量上的灵活操作。我们不需要定义矩阵 D ^ \hat{D} D^,我们可以在之后简单地将总和消息除以邻居的数量。此外,我们用线性层替换权重矩阵,这也允许我们添加偏置。以下是 PyTorch 模块中 GCN 层的定义:

class GCNLayer(nn.Module):
def __init__(self, c_in, c_out):
super().__init__()
self.projection = nn.Linear(c_in, c_out)  # 定义线性层

def forward(self, node_feats, adj_matrix):
"""
前向传播函数
:param node_feats: 节点特征张量,形状为 [batch_size, num_nodes, c_in]
:param adj_matrix: 邻接矩阵的批次,形状为 [batch_size, num_nodes, num_nodes]
:return: 节点的新特征
&#
GNNPyTorch代码可以如下所示: ```python import torch import torch.nn as nn class GNN(nn.Module): def __init__(self, input_size, hidden_size, out_channels): super(GNN, self).__init__() self.conv = nn.Conv1d(input_size, out_channels, kernel_size=3, padding=1) self.relu = nn.ReLU() self.fc = nn.Linear(hidden_size, out_channels) def forward(self, x): x = self.conv(x) x = self.relu(x) x = torch.max(x, dim=2)\[0\] x = self.fc(x) return x ``` 在这个代码中,我们定义了一个GNN模型,它包含了一个一维卷积层一个全连接层。一维卷积层用于对输入进行空间上的卷积变换,全连接层用于将卷积结果映射到最终的输出维度。在forward方法中,我们首先对输入进行一维卷积操作,然后通过ReLU激活函数进行非线性变换,接着使用最大池化操作获取每个样本的最大值,最后将结果输入到全连接层得到最终的输出。 请注意,这只是一个简单的示例代码,实际应用中可能需要根据具体任务进行适当的修改调整。 #### 引用[.reference_title] - *1* *2* *3* [PyTorch搭建GNN-LSTMLSTM-GNN模型实现多变量输入多变量输出时间序列预测](https://blog.csdn.net/Cyril_KI/article/details/128621012)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^koosearch_v1,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

MUKAMO

你的鼓励是我们创作最大的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值