一文讲懂图神经网络GNN(剧本杀版)

1. 图神经网络(GNN)简介

图神经网络(Graph Neural Networks,GNN)是一类专门设计用于处理图结构数据的深度学习模型。它能够有效地捕捉和利用数据中的关系信息,使其在许多领域中展现出强大的潜力。为了便于各位看官易于理解GNN中节点、边和图的概念,小编这里构造一个虚拟的剧本杀游戏如下:

1.1 故事背景:

在一座古老的庄园里,富豪老爷爷突然离奇死亡。他的遗嘱中提到,只有解开庄园的秘密,才能继承他的遗产。六位与老爷爷有关的人被邀请到庄园中解开谜题。

主要人物:
詹姆斯(老管家)
玛丽(老爷爷的女儿)
托马斯(玛丽的丈夫)
莉莉(老爷爷的孙女)
亚瑟(家族律师)
威廉(老爷爷的私人医生)

1.2 人物关系图:

在这里插入图片描述

1.3 故事情节:

参与者需要在庄园中搜寻线索,解开谜题,同时调查老爷爷的死因。每个人物都有自己的秘密和动机,可能是凶手,也可能是无辜的。玩家需要通过对话、搜证和逻辑推理来揭示真相。

1.4 引入GNN图神经网络

那么在这个故事中,节点、边和图分别表示什么呢?

1.4.1 节点:

在这个剧本杀剧情中,每个人物可以被表示为一个节点,每个节点的特征可以包括:【年龄,与死者的关系,财务状况,性格特征,在案发时的不在场证明】。

1.4.2 边:

人物之间的关系可以用边来表示,边的特征可以包括:【关系类型,关系强度,互动频率,潜在冲突】

1.4.3 图:

在这里插入图片描述

1.4.4 GNN任务:

  • 节点分类:预测每个人物是否可能是凶手
  • 边预测:推测人物之间可能存在的隐藏关系
  • 图分类:根据整体图结构判断案件的类型(谋杀、意外等)
    想必看到这里大家对于GNN中节点、边和图这几个词建立了一些概念,那么接下来我们详细讲解一下什么是GNN。

2. 为什么引入 GNN

首先,大家可能会疑惑,为什么要提出GNN网络呢,现有的CNN、RNN和Transformer等网络难道还不够用吗?

2.1 GNN vs 传统神经网络

  • CNN:擅长处理规则网格数据如图像,操作简单。
  • RNN:适用于序列数据,如文本或时间序列。
  • Transformer:善于处理长距离依赖的序列数据,。

然而,这些传统模型在处理图结构数据时存在局限性。图数据的不规则性和复杂的关系结构使得传统方法难以直接应用。且传统的结构难以捕获数据间的拓扑结构。比如对于一个化学分子结构:
在这里插入图片描述
对于该领域的问题,传统的CNN、RNN或者Transformer根本就不适用,相反GNN能够捕捉数据的拓扑结构信息,其中每个原子就是一个节点,每个化学键就是一个边。

2.2 GNN 还有较为广泛的应用领域:

  • 社交网络分析:预测用户行为,检测社区结构。
  • 推荐系统:基于用户-物品图的个性化推荐。
  • 生物信息学:蛋白质结构预测,药物相互作用分析。
  • 交通流量预测:基于道路网络的交通状况分析。
  • 知识图谱:实体关系推理,知识补全。

3. 图神经网络的原理

3.1 GNN 的基本模块

  • 节点(Nodes):图中的实体,如剧本杀中的角色。
  • 边(Edges):节点之间的关系,如角色间的社交关系。
  • 图(Graph):由节点和边组成的整体结构。

3.2 邻接矩阵

在这里插入图片描述

邻接矩阵是表示图结构的数学工具。对于有 N 个节点的图,邻接矩阵 A 是一个 N×N 的矩阵:
A i j = { 1 , 如果节点 i 和节点 j 之间有边  0 , 其他情况 A_{ij} = \begin{cases} 1, & \text{如果节点 i 和节点 j 之间有边} \ 0, & \text{其他情况} \end{cases} Aij={1,如果节点 i 和节点 j 之间有边 0,其他情况
在这里插入图片描述

3.3 GNN 消息传递

消息传递机制是图神经网络的核心,它允许节点之间交换信息,从而学习到更丰富的特征表示。这个过程通常包括三个主要步骤:消息生成、消息聚合和节点更新。

3.3.1 消息生成

在这个阶段,每个节点会根据自身的特征和与邻居节点的连接关系生成消息:
在这里插入图片描述

3.3.2 消息聚合

节点收集来自其所有邻居的消息,并使用某种聚合函数(如求和、平均或最大值)将这些消息组合起来:
在这里插入图片描述

3.3.3 节点更新

基于聚合的消息和节点自身的当前状态,更新节点的特征表示:
在这里插入图片描述

3.3.4 更新过程

节点 B 和 C 生成发送给 A 的消息
A 聚合来自 B 和 C 的消息
A 基于聚合的消息和自身当前状态更新其特征
在这里插入图片描述

3.4 GNN 实例

考虑以下简单的无向图:
在这里插入图片描述
每个节点初始有一个标量特征值(如A:1表示节点A的初始特征值为1)。

3.4.1 GNN 实例使用的函数

  • 消息函数(MSG):取发送节点和接收节点特征的平均值
  • 聚合函数(AGGREGATE):对所有接收到的消息求和
  • 更新函数(UPDATE):将聚合的消息加到当前节点特征上

3.4.2 第一轮消息传递

  • 消息生成
    对于每条边,计算消息:
    A -> B: MSG(1, 2) = (1 + 2) / 2 = 1.5
    A -> C: MSG(1, 3) = (1 + 3) / 2 = 2
    B -> D: MSG(2, 4) = (2 + 4) / 2 = 3
    C -> D: MSG(3, 4) = (3 + 4) / 2 = 3.5

  • 消息聚合
    每个节点聚合收到的所有消息:

    A: AGGREGATE(1.5, 2) = 3.5
    B: AGGREGATE(1.5) = 1.5
    C: AGGREGATE(2) = 2
    D: AGGREGATE(3, 3.5) = 6.5

  • 节点更新
    更新每个节点的特征:

    A: UPDATE(1, 3.5) = 1 + 3.5 = 4.5
    B: UPDATE(2, 1.5) = 2 + 1.5 = 3.5
    C: UPDATE(3, 2) = 3 + 2 = 5
    D: UPDATE(4, 6.5) = 4 + 6.5 = 10.5

  • 第一轮后的图结构
    在这里插入图片描述

3.4.3 第二轮消息传递

  • 消息生成
    A -> B: MSG(4.5, 3.5) = (4.5 + 3.5) / 2 = 4
    A -> C: MSG(4.5, 5) = (4.5 + 5) / 2 = 4.75
    B -> D: MSG(3.5, 10.5) = (3.5 + 10.5) / 2 = 7
    C -> D: MSG(5, 10.5) = (5 + 10.5) / 2 = 7.75
  • 消息聚合
    A: AGGREGATE(4, 4.75) = 8.75
    B: AGGREGATE(4) = 4
    C: AGGREGATE(4.75) = 4.75
    D: AGGREGATE(7, 7.75) = 14.75
  • 节点更新
    A: UPDATE(4.5, 8.75) = 4.5 + 8.75 = 13.25
    B: UPDATE(3.5, 4) = 3.5 + 4 = 7.5
    C: UPDATE(5, 4.75) = 5 + 4.75 = 9.75
    D: UPDATE(10.5, 14.75) = 10.5 + 14.75 = 25.25
  • 第二轮后的图结构
    在这里插入图片描述

3.4.4 结果分析

  • 信息传播:可以看到,经过两轮消息传递,每个节点的特征值都发生了显著变化。这反映了图中信息的传播过程。
  • 中心性:节点D的特征值增长最快,这反映了它在图中的中心地位(连接度最高)。
  • 特征融合:每个节点的新特征不仅包含了自身的信息,还融合了邻居节点的信息。

4. 实例代码

下面是一个完整的 PyTorch 代码,包含了模型定义、数据加载和训练过程:
首先,我们需要安装必要的库:

pip install torch torch_geometric

完整代码如下:

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

class GCN(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_features, 16)
        self.conv2 = GCNConv(16, 16)
        self.classifier = torch.nn.Linear(16, num_classes)

    def forward(self, x, edge_index):
        h = F.relu(self.conv1(x, edge_index))
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.conv2(h, edge_index)
        h = F.dropout(h, p=0.5, training=self.training)
        out = self.classifier(h)
        return out

# 加载Cora数据集
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())
data = dataset[0]

# 初始化模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(num_features=dataset.num_features, num_classes=dataset.num_classes).to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# 训练模型
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    
    if (epoch + 1) % 10 == 0:
        print(f'Epoch {epoch+1:3d}, Loss: {loss.item():.4f}')

# 评估模型
model.eval()
pred = model(data.x, data.edge_index).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f'Accuracy: {acc:.4f}')
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值