目录
- 模型结构详解
- 数学原理与推导
- 代表性变体及改进
- 应用场景与优缺点
- PyTorch代码示例
1. 模型结构详解
1.1 核心组件
节点特征 → 邻居聚合 → 特征更新 → 输出
1.1.1 输入输出
- 输入:
- 图结构:邻接矩阵A ∈ {0,1}^{N×N}
- 节点特征:X ∈ R^{N×d}
- 边特征(可选):E ∈ R^{M×c}
- 输出:
- 节点级:Z ∈ R^{N×k}
- 图级:z ∈ R^k
1.1.2 消息传递范式
- 消息生成:对每条边计算消息
- 聚合操作:收集邻居消息
- 状态更新:结合自身特征与聚合消息
1.1.3 激活函数
- 中间层:ReLU/LeakyReLU
- 输出层:Sigmoid(分类)、Linear(回归)
2. 数学原理与推导
2.1 图卷积网络(GCN)公式

其中:
- A~=A+I(添加自连接)
- D~ii=∑jA~ij(度矩阵)
- W(l)为可学习参数
2.2 图注意力网络(GAT)
注意力系数计算:
聚合公式:
2.3 通用框架(消息传递)
数学形式化:
其中:
- ϕ:更新函数
- ψ:消息函数
- ⨁:聚合算子(sum/mean/max)
3. 代表性变体及改进
3.1 卷积类GNN
3.1.1 GraphSAGE
- 改进点:归纳式学习,支持动态图
- 聚合方式:
3.1.2 GIN (Graph Isomorphism Network)
- 理论基础:Weisfeiler-Lehman同构测试
- 聚合公式:
3.2 注意力类GNN
3.2.1 GATv2
- 改进点:动态注意力机制
- 计算优化:
3.2.2 Transformer for Graphs
- 全局注意力:节点与全图交互
3.3 自动编码器类GNN
3.3.1 Graph Autoencoder (GAE)
- 编码器:GCN生成节点嵌入
- 解码器:重构邻接矩阵
3.3.2 VGAE (Variational GAE)
- 引入变分推断:
其中:
4. 应用场景与优缺点
4.1 应用场景
领域 | 任务示例 | 适用模型 |
---|---|---|
社交网络 | 社区发现、影响力预测 | GCN/GAT |
化学分子 | 分子性质预测、药物发现 | GIN/MPNN |
推荐系统 | 用户-商品二部图推荐 | PinSAGE |
交通网络 | 交通流量预测 | ST-GCN |
知识图谱 | 实体链接、关系推理 | RGCN |
4.2 优缺点对比
优点 | 缺点 |
---|---|
处理非欧式数据结构 | 大规模图计算复杂度高 |
捕获拓扑结构与特征关联 | 层数受限(过平滑问题) |
支持节点/边/图多层级任务 | 动态图更新效率低 |
可解释性强(注意力权重可视化) | 异构图建模复杂度高 |
5. PyTorch代码示例
5.1 GCN实现(节点分类)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class GCN(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super().__init__()
self.conv1 = GCNConv(input_dim, hidden_dim)
self.conv2 = GCNConv(hidden_dim, output_dim)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
# 使用PyTorch Geometric数据加载
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0].to(device)
model = GCN(dataset.num_features, 16, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
def train():
model.train()
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
return loss.item()
for epoch in range(200):
loss = train()
print(f'Epoch {epoch:03d}, Loss: {loss:.4f}')
5.2 GAT实现片段
from torch_geometric.nn import GATConv
class GAT(nn.Module):
def __init__(self, in_dim, hid_dim, out_dim, heads=8):
super().__init__()
self.conv1 = GATConv(in_dim, hid_dim, heads=heads)
self.conv2 = GATConv(hid_dim*heads, out_dim, heads=1)
def forward(self, x, edge_index):
x = F.elu(self.conv1(x, edge_index))
x = F.dropout(x, p=0.6, training=self.training)
x = self.conv2(x, edge_index)
return x
核心总结
- 设计哲学:通过消息传递捕捉图结构依赖
- 数学本质:广义邻域特征聚合与非线性变换
- 工程挑战:
- 大规模图的高效计算(采样、分区)
- 动态图实时更新
- 前沿方向:
- 时空图网络(ST-GNN)
- 图对比学习(GraphCL)
- 图结构学习(Jointly Learn Graph)