PyG教程(5):剖析GNN中的消息传播机制

一.前言

众所周知,图神经网络可以从空域或谱域来对其进行研究。其中,空域角度主要借助消息传播机制来构建GNN。本文主要介绍的是消息传递机制,为下篇文章具体介绍PyG中是如何实现消息传播机制做好铺垫。

二.消息传递框架概述

消息传递是Gilmer等在Neural Message Passing for Quantum Chemistry中提出来的从空域角度定义GNN的范式。假设 x i ( k − 1 ) ∈ R F \mathbf{x}^{(k-1)}_i \in \mathbb{R}^F xi(k1)RF表示节点 i i i在第 k − 1 k-1 k1层的特征 e j , i ∈ R D \mathbf{e}_{j,i} \in \mathbb{R}^D ej,iRD表示节点 j j j到节点 i i i的边上的特征,则消息传播机制可以用如下公式来描述:
x i ( k ) = γ ( k ) ( x i ( k − 1 ) , □ j ∈ N ( i )   ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) ) (1) \mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right) \tag{1} xi(k)=γ(k)(xi(k1),jN(i)ϕ(k)(xi(k1),xj(k1),ej,i))(1)
在消息传播机制中,主要分为三大步骤:消息生成消息聚合消息更新

2.1 消息生成与传播

在本阶段中,每个节点将生成自己的消息,然后向自己的邻居节点”传播“自己的消息,也就是公式(1)中的:
ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) ϕ(k)(xi(k1),xj(k1),ej,i)
其中, ϕ ( k ) \phi^{(k)} ϕ(k)表示可微函数,例如MLP。在生成消息的过程中,可能会用到:

  • 节点自己当前的特征 x i ( k − 1 ) \mathbf{x}_i^{(k-1)} xi(k1)
  • 节点邻居当前的特征 x j ( k − 1 ) \mathbf{x}_j^{(k-1)} xj(k1)
  • 节点与其邻居间边的特征 e j , i \mathbf{e}_{j,i} ej,i

当然上述三者并不都是必须的,具体使用什么来生成节点的消息取决于GNN的构建者。

2.2 消息聚合

在本阶段,每个节点会聚合来自邻居的消息,也就是公式(1)中的:
□ j ∈ N ( i ) ( Message ) \square_{j \in \mathcal{N}(i)}(\text{Message}) jN(i)(Message)
其中 Message \text{Message} Message指代2.1节中每个节点的消息, N ( i ) \mathcal{N}(i) N(i)表示节点 i i i的邻域, □ \square 表示可微的、转置不变(permutation invariant)函数。转置不变指聚合邻居的消息的结果与邻居的聚合顺序无关,常见的包括sum,max,mean

2.3 消息更新

在本阶段,每个节点利用聚合自邻居节点的消息来生成自己的消息,也就是公式(1)中的:
γ ( k ) ( x i ( k − 1 ) , NeighborMsg ) \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \text{NeighborMsg}\right) γ(k)(xi(k1),NeighborMsg)
其中 N e i g h b o r M s g NeighborMsg NeighborMsg指代2.2节中每个节点聚合自邻居的消息, γ ( k ) \gamma^{(k)} γ(k)也表示可微函数,例如MLP。

2.4 消息传递机制小结

上述的消息传播机制可以用下图概括:

message_passing

其中图中的方框便是的便是聚合邻居的神经网络。

经过前面的介绍可知:空域角度定义的GNN间的不同之处便在于它们关于消息生成、消息聚合和消息更新的实现的不同

三.消息传播机制的示例

为了方便理解上面的消息传递机制,本节主要展示GCN、GraphSAGE中的消息传播机制。

3.1 GCN中的消息传播机制

GCN虽然是从谱域角度定义的,但同样从空域角度来对其进行解释,其所对应的消息传播机制如下:
h i ( l + 1 ) = σ ( 1 ∣ N ( i ) ∣ ∑ j ∈ N ( i ) 1 ∣ N ( j ) ∣ h j ( l ) W ( l ) ) h_i^{(l+1)} = \sigma(\frac{1}{\sqrt{|\mathcal{N}(i)|}}\sum_{j\in\mathcal{N}(i)}\frac{1}{\sqrt{|\mathcal{N}(j)|}}h_j^{(l)}W^{(l)}) hi(l+1)=σ(N(i) 1jN(i)N(j) 1hj(l)W(l))
生成消息
1 ∣ N ( j ) ∣ h j ( l ) W ( l ) \frac{1}{\sqrt{|\mathcal{N}(j)|}}h_j^{(l)}W^{(l)} N(j) 1hj(l)W(l)
消息聚合
σ ( 1 ∣ N ( i ) ∣ ∑ j ∈ N ( i ) Message ) \sigma(\frac{1}{\sqrt{|\mathcal{N}(i)|}}\sum_{j\in\mathcal{N}(i)}{\text{Message}}) σ(N(i) 1jN(i)Message)
其中 σ \sigma σ指非线性激活,通常为ReLU。

消息更新:消息聚合来的消息。

3.2 GraphSAGE中的消息传播机制

GraphSAGE中的消息传播机制如下所示:
h i ( l + 1 ) = σ ( W ( l ) ⋅ CONCAT ⁡ ( h i ( l ) , A G G ( { h j ( l ) , ∀ j ∈ N ( i ) } ) ) ) h_{i}^{(l + 1)}=\sigma\left(W^{(l)} \cdot \operatorname{CONCAT}\left(h_{i}^{(l)}, \mathrm{AGG}\left(\left\{h_{j}^{(l)}, \forall j \in N(i)\right\}\right)\right)\right) hi(l+1)=σ(W(l)CONCAT(hi(l),AGG({hj(l),jN(i)})))
其中 norm \text{norm} norm表示L2正则化。

消息生成
h j h_j hj
两阶段聚合(首先聚合邻居的消息,然后聚合自身的消息)
h N ( i ) ( l + 1 ) ← A G G ( { h j ( l ) , ∀ j ∈ N ( i ) } ) σ ( W ( l ) ⋅ CONCAT ⁡ ( h i ( l ) , h N ( i ) ( l + 1 ) ) ) h_{N(i)}^{(l + 1)} \leftarrow \mathrm{AGG}\left(\left\{h_{j}^{(l)}, \forall j \in N(i)\right\}\right) \\ \sigma\left(W^{(l)} \cdot \operatorname{CONCAT}\left(h_{i}^{(l)},h_{N(i)}^{(l + 1)}\right)\right) hN(i)(l+1)AGG({hj(l),jN(i)})σ(W(l)CONCAT(hi(l),hN(i)(l+1)))
其中 AGG \text{AGG} AGG在GraphSAGE中有三种实现:Mean、Pool或LSTM

消息更新:消息聚合来的消息。

四.结语

参考资料:

GNN中的消息传播机制是借助PyG之类的图神经网络框架来编写自己的消息传播GNN的基础,只有对其了解较为深刻,你才能更好的设计自己的GNN模型。

### 实现基于 PyTorch Geometric 的 GCN 和 GNN 消息传递机制 #### 使用 `MessagePassing` 类构建自定义层 为了在 PyTorch Geometric 中实现消息传递机制,通常会继承 `torch_geometric.nn.MessagePassing` 并重写其核心函数。这允许创建定制化的图卷积网络或其他类型的图神经网络。 ```python import torch from torch.nn import Parameter from torch_geometric.nn.conv import MessagePassing from torch_geometric.utils import add_self_loops, degree class CustomGCNConv(MessagePassing): def __init__(self, in_channels, out_channels): super(CustomGCNConv, self).__init__(aggr='add') # "Add" aggregation. self.lin = torch.nn.Linear(in_channels, out_channels) def forward(self, x, edge_index): # Step 1: Add self-loops to the adjacency matrix. edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0)) # Step 2: Linearly transform node feature matrix. x = self.lin(x) # Step 3-5: Start propagating messages. return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x) def message(self, x_j, edge_index, size): # Normalize node features by their degrees (message passing step). row, col = edge_index deg = degree(row, size[0], dtype=x_j.dtype) deg_inv_sqrt = deg.pow(-0.5) norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] return norm.view(-1, 1) * x_j def update(self, aggr_out): # No additional transformation is applied during this stage. return aggr_out ``` 此代码片段展示了如何通过扩展 `MessagePassing` 来设计自己的 GCN 层[^2]。这里的关键在于: - **初始化 (`__init__`)** 定义了聚合方式(这里是加法),并设置了线性变换参数用于节点特征转换。 - **前向传播 (`forward`)** 添加自环到邻接矩阵,并执行线性映射操作;随后调用父类的方法启动消息传播过程。 - **消息计算 (`message`)** 计算每条边上的权重因子——即源节点度数平方根倒数乘以目标节点度数平方根倒数的结果作为规范化系数应用至邻居节点特征上。 - **更新状态 (`update`)** 将聚集后的输出直接返回给下一个模块使用而不做任何额外修改。 这种结构化的方式不仅适用于标准的 GCNs ,也可以很容易地调整成其他变体形式如 GraphSAGE 或者 GATs 等等。 对于更复杂的场景,则可能涉及到多跳传播、跳跃连接以及残差学习等多种高级特性。这些都可以在这个框架基础上进一步开发和完善[^1]。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

斯曦巍峨

码文不易,有条件的可以支持一下

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值