Pytorch-geometric: Creating Message Passing Networks 构建消息传递网络教程
一、背景
将卷积运算推广到不规则域通常表示为邻局聚合(neighborhood aggregation)或消息传递(neighborhood aggregation)模式。
x
i
(
k
−
1
)
∈
R
1
×
D
\mathbf{x}^{(k-1)}_i \in \mathbb{R}^{1 \times D}
xi(k−1)∈R1×D表示节点
i
i
i在第
(
k
−
1
)
(k-1)
(k−1)层的节点特征,
e
j
,
i
∈
R
1
×
F
\mathbf{e}_{j,i} \in \mathbb{R}^{1 \times F}
ej,i∈R1×F表示节点
j
j
j到节点的
i
i
i边特征(可选的),消息传递图神经网络可以描述为:
x
i
(
k
)
=
γ
(
k
)
(
x
i
(
k
−
1
)
,
□
j
∈
N
(
i
)
ϕ
(
k
)
(
x
i
(
k
−
1
)
,
x
j
(
k
−
1
)
,
e
j
,
i
)
)
,
\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right),
xi(k)=γ(k)(xi(k−1),□j∈N(i)ϕ(k)(xi(k−1),xj(k−1),ej,i)),
其中,
□
\square
□表示可微且置换不变的聚合函数(aggregation function),例如, sum
、mean
或max
,消息函数(message function)
ϕ
\phi
ϕ 和更新函数(update function)
γ
\gamma
γ均为可微函数,例如MLP。
值得注意的是,一般GNN论文中通常给出的是聚合邻居信息的Aggregator和更新节点表示Updator,其Aggregator对应pytorch-geometric(PyG)中的消息函数和聚合函数。GNN本质上还是在做特征传播。
x N i ( k ) = AGGREGATE ( k ) ( { x j ( k − 1 ) , ∀ j ∈ N i } ) \mathbf{x}_{\mathcal{N}_{i}}^{(k)}=\text { AGGREGATE }_{(k)}\left(\left\{\mathbf{x}_{j}^{(k-1)}, \forall j \in \mathcal{N}_{i}\right\}\right) xNi(k)= AGGREGATE (k)({xj(k−1),∀j∈Ni}) x i ( k ) = σ ( W ( k ) ⋅ [ x i ( k − 1 ) ∥ x N i ( k ) ] ) \mathbf{x}_{i}^{(k)}=\sigma\left(\mathbf{W}^{(k)} \cdot\left[\mathbf{x}_{i}^{(k-1)} \| \mathbf{x}_{\mathcal{N}_{i}}^{(k)}\right]\right) xi(k)=σ(W(k)⋅[xi(k−1)∥xNi(k)])
例如,在GraphSage中,消息函数直接获取邻居节点 j ∈ N i j \in \mathcal{N}_{i} j∈Ni在第 k − 1 k-1 k−1层的嵌入,然后使用mean、max或LSTM作为聚合函数,更新函数将邻居中间嵌入和目标节点 i i i自身嵌入拼接后做线性变化。
α
i
j
=
exp
(
Leaky ReLU
(
a
T
[
W
x
i
∥
W
x
j
]
)
)
∑
k
∈
N
i
exp
(
Leaky ReLU
(
a
T
[
W
x
i
∥
W
x
k
]
)
)
\alpha_{i j}=\frac{\exp \left(\text { Leaky ReLU }\left(\mathbf{a}^{T}\left[\mathbf{W} \mathbf{x}_{i} \| \mathbf{W} \mathbf{x}_{j}\right]\right)\right)}{\sum_{k \in \mathcal{N}_{i}} \exp \left(\text { Leaky ReLU }\left(\mathbf{a}^{T}\left[\mathbf{W} \mathbf{x}_{i} \| \mathbf{W} \mathbf{x}_{k}\right]\right)\right)}
αij=∑k∈Niexp( Leaky ReLU (aT[Wxi∥Wxk]))exp( Leaky ReLU (aT[Wxi∥Wxj]))
x
i
′
=
∥
k
=
1
K
σ
(
∑
j
∈
N
i
α
i
j
k
W
k
x
j
)
\mathbf{x}_{i}^{\prime}=\|_{k=1}^{K} \sigma\left(\sum_{j \in \mathcal{N}_{i}} \alpha_{i j}^{k} \mathbf{W}^{k} \mathbf{x}_{j}\right)
xi′=∥k=1Kσ
j∈Ni∑αijkWkxj
又例如,在GAT中,消息函数根据注意力系数对节点嵌入进行归一化,然后使用"add"
作为聚合函数。
二、MessagePassing基类
PyG的torch_geometric.nn
中提供了MessagePassing
基类,它通过自动处理消息传播来帮助创建此类消息传递图神经网络。用户只需重新定义
ϕ
\phi
ϕmessage()
和
γ
\gamma
γupdate()
及aggregation聚合方式(函数),例如aggr="add"
, aggr="mean"
or aggr="max"
,就可以实现自己GNN模型。
借助以下4个方法可实现上述目的:
MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)
:定义要使用的聚合方案("add"
,"mean"
或"max"
)和消息传递的流向("source_to_target"
或"target_to_source"
)。 此外,node_dim
属性指明沿哪个轴传播。
MessagePassing.propagate(edge_index, size=None, **kwargs)
: 开始传播消息的初始调用。它接收边索引edge_index
和构造消息所需的所有其他数据,来更新节点嵌入。propagate()
不仅可以在[N, N]
的方矩中交换消息,还可通过传入size=(N, M)
作为附加参数传递来交换形如[N, M]
的稀疏分配矩阵(例如,推荐系统中的二部图)中的消息。如果size
设为None
,则矩阵为方阵。
MessagePassing.message(...)
:类似
ϕ
\phi
ϕ,构造每条边到节点
i
i
i的消息。若 flow="source_to_target"
则
(
j
,
i
)
∈
E
(j,i) \in \mathcal{E}
(j,i)∈E和flow="target_to_source"
则
(
i
,
j
)
∈
E
(i,j) \in \mathcal{E}
(i,j)∈E。它可接受最初传递给propagate()
的任何参数。 此外,传递给propagate()
的tensors可通过添加后缀_i
和_j
到变量名(例如,x_i
和x_j
)映射到对应的节点
i
i
i和
j
j
j。根据习惯,通常用
i
i
i表示聚合信息的中心节点(目标target),并用
j
j
j表示邻居节点(源source)。
MessagePassing.update(aggr_out, ...)
:类似
γ
\gamma
γ,更新每个节点
i
∈
V
i \in \mathcal{V}
i∈V的嵌入。聚合操作的输出aggr_out
作为其第一个参数,以及最初传递给propagate()
的任何参数。
三、例子
接下来,将通过MessagePassing实现GCN和EdgeConv来作进一步介绍。为便于表示,将节点特征表示为行向量。
3.1 实现GCN层
矩阵形式的GCN层:
X
(
k
)
=
σ
(
A
^
X
(
k
−
1
)
W
(
k
)
)
\mathbf{X}^{(k)} =\sigma\left(\hat{\mathbf{A}} \mathbf{X}^{(k-1)} \mathbf{W}^{(k)} \right)
X(k)=σ(A^X(k−1)W(k))其中,
A
^
=
D
~
−
1
2
A
~
D
~
−
1
2
∈
R
N
×
N
\hat{\mathbf{A}}=\tilde{\mathbf{D}}^{-\frac{1}{2}} \tilde{\mathbf{A}} \tilde{\mathbf{D}}^{-\frac{1}{2}} \in \mathbb{R}^{N \times N}
A^=D~−21A~D~−21∈RN×N为自环归一化邻接矩阵,
A
~
=
A
+
I
\tilde{\mathbf{A}}=\mathbf{A}+\mathbf{I}
A~=A+I在原始邻接矩阵上加自环连接,
D
~
=
D
+
I
\tilde{\mathbf{D}}=\mathbf{D}+\mathbf{I}
D~=D+I,
X
(
k
−
1
)
∈
R
N
×
D
\mathbf{X}^{(k-1)}\in \mathbb{R}^{N \times D}
X(k−1)∈RN×D,
W
(
k
)
∈
R
D
×
D
\mathbf{W}^{(k)}\in \mathbb{R}^{D\times D}
W(k)∈RD×D。
将 A ^ \hat{\mathbf{A}} A^在节点层面展开:
- A ~ \tilde{\mathbf{A}} A~先左乘 D ~ − 1 2 \tilde{\mathbf{D}}^{-\frac{1}{2}} D~−21做行变化,即对 A ~ \tilde{\mathbf{A}} A~的每一行 A ~ i : \tilde{\mathbf{A}}_{i:} A~i:按节点 i i i的度 d e g ( i ) − 1 2 deg(i)^{-\frac{1}{2}} deg(i)−21进行归一化(假设 A ~ \tilde{\mathbf{A}} A~为指示矩阵,除自己之外,只有节点 i i i的一阶邻居 j ∈ N ( i ) j \in \mathcal{N}(i) j∈N(i)的值 A ~ i j \tilde{\mathbf{A}}_{ij} A~ij为1)。
- D ~ − 1 2 A ~ \tilde{\mathbf{D}}^{-\frac{1}{2}} \tilde{\mathbf{A}} D~−21A~再右乘 D ~ − 1 2 \tilde{\mathbf{D}}^{-\frac{1}{2}} D~−21做列变化,即对每一列 ( D ~ − 1 2 A ~ ) : j (\tilde{\mathbf{D}}^{-\frac{1}{2}} \tilde{\mathbf{A}})_{:j} (D~−21A~):j按节点 j j j的度 d e g ( j ) − 1 2 deg(j)^{-\frac{1}{2}} deg(j)−21再做归一化。
此时, A ^ i j = A ~ i j d e g ( i ) − 1 2 d e g ( j ) − 1 2 \hat{\mathbf{A}}_{ij}=\tilde{\mathbf{A}}_{ij} deg(i)^{-\frac{1}{2}} deg(j)^{-\frac{1}{2}} A^ij=A~ijdeg(i)−21deg(j)−21,即将满足 A ~ i j ≠ 1 \tilde{\mathbf{A}}_{ij} \neq 1 A~ij=1的边 e i j e_{ij} eij 对应的节点对 < i , j > <i,j> <i,j>的度来进行归一化。
A ^ \hat{\mathbf{A}} A^ 的行或列之和并不为一,它不同于可视为概率转移矩阵的简单行归一化 D ~ − 1 A ~ \tilde{\mathbf{D}}^{-1} \tilde{\mathbf{A}} D~−1A~ 或列归一化 A ~ D ~ − 1 \tilde{\mathbf{A}} \tilde{\mathbf{D}}^{-1} A~D~−1 。
A ^ \hat{\mathbf{A}} A^ 右乘 X ( k − 1 ) \mathbf{X}^{(k-1)} X(k−1),相当于用 A ^ \hat{\mathbf{A}} A^的每一行的系数对节点的行向量矩阵做线性组合。其中,节点 i i i在第 k k k层的表示 x ( k ) ∈ R 1 × D \mathbf{x}^{(k)} \in \mathbb{R}^{1 \times D} x(k)∈R1×D是由 A ^ i : ∈ R 1 × N \hat{\mathbf{A}}_{i:} \in \mathbb{R}^{1 \times N} A^i:∈R1×N乘以 X ( k − 1 ) ∈ R N × D \mathbf{X}^{(k-1)} \in \mathbb{R}^{N \times D} X(k−1)∈RN×D,等价于直接以加权系数 A ^ i : \hat{\mathbf{A}}_{i:} A^i:对节点 i i i的一阶邻居 N ( i ) \mathcal{N}(i) N(i)以及 i i i自己的节点表示做线性组合(加权求和)。
由此,可得到空域视角的GCN层的定义:
x
i
(
k
)
=
∑
j
∈
N
(
i
)
∪
{
i
}
1
deg
(
i
)
⋅
deg
(
j
)
⋅
(
x
j
(
k
−
1
)
W
(
k
)
)
+
b
,
\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{x}_j^{(k-1)} \mathbf{W}^{(k)} \right) + \mathbf{b},
xi(k)=j∈N(i)∪{i}∑deg(i)⋅deg(j)1⋅(xj(k−1)W(k))+b,
其中, 邻居节点的特征先经权重矩阵 W ( k ) \mathbf{W}^{(k)} W(k) 做变换,再按它们的度做归一化,最后求和。最后,将偏置向量应用于聚合输出。
GCN公式可分为以下步骤:
- 将自环连接加到邻接矩阵上
- 线性变换节点特征矩阵
- 计算归一化系数
- 归一化节点特征(制作message的过程)
- 使用
"add"
方法聚合节点特征(先汇聚邻居节点特征,再和目标节点特征合并) - 加上偏置向量(bias为可选项)。
第1-3步通常在消息传递前计算,4-5步可用MessagePassing
基类轻松实现。完整实现如下所示:
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
# in_channels为输入节点特征维度, out_channels为输出节点特征维度
# 初始化GCN层中的线性变换权重矩阵和bias向量
super().__init__(aggr='add') # "Add" aggregation (Step 5).
self.lin = Linear(in_channels, out_channels, bias=False)
self.bias = Parameter(torch.Tensor(out_channels))
self.reset_parameters()
def reset_parameters(self):
# 参数初始化
self.lin.reset_parameters()
self.bias.data.zero_()
def forward(self, x, edge_index):
# 节点特征矩阵x的shape为[N, in_channels]
# 边索引edge_index的shape为[2, E]
# Step 1: 将自环连接加到邻接矩阵上
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: 对节点特征矩阵做线性变换
x = self.lin(x)
# Step 3: 计算归一化系数
row, col = edge_index # 分别取出边索引的两部分
# 由于GCN一般将图视为无向,row或col中分别包含所有节点的索引,故可根据col统计节点的度
deg = degree(col, x.size(0), dtype=x.dtype) # 度对角矩阵
deg_inv_sqrt = deg.pow(-0.5) # 对角元素开负根号
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
# 归一化系数实际对应每条边,直接用边索引取度相乘即可
# norm的shape为[E, 1], E为边数量
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Step 4-5:开始传播消息
out = self.propagate(edge_index, x=x, norm=norm)
# Step 6: 加偏置向量
out += self.bias
return out
def message(self, x_j, norm):
# x_j的shape为[E, out_channels]
# Step 4: 归一化节点特征(先将norm系数变为列向量,再和x_j做点乘)
return norm.view(-1, 1) * x_j
GCNConv
继承了使用"add"
聚合操作的MessagePass
。GCN层的所有计算逻辑都包含在其forward
方法中。
在计算好归一化系数norm
后(在GCN中norm
固定),将调用propagate()
,该函数内部会调用message()
、update()
和aggregate()
。除了edge_index
, 节点嵌入x
和归一化系数norm
将作为GCN消息传播的附加参数。
在message()
函数中,需通过norm
对邻居节点特征进行归一化。这里,x_j
表示一个 a lifted tensor,它包含每个边的source源节点特征,即每个节点的邻居。
以上就是创建一个简单的消息传递层所需的全部内容。此层可用作深层GNN的基础模块。初始化和调用它很简单:
conv = GCNConv(16, 32)
x = conv(x, edge_index)
3.2 实现EdgeConv层
边卷积层可以处理处理图或点云,它在数学上定义为:
x
i
(
k
)
=
max
j
∈
N
(
i
)
h
Θ
(
x
i
(
k
−
1
)
,
x
j
(
k
−
1
)
−
x
i
(
k
−
1
)
)
,
\mathbf{x}_i^{(k)} = \max_{j \in \mathcal{N}(i)} h_{\mathbf{\Theta}} \left( \mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)} - \mathbf{x}_i^{(k-1)} \right),
xi(k)=j∈N(i)maxhΘ(xi(k−1),xj(k−1)−xi(k−1)), 其中,
h
Θ
h_{\mathbf{\Theta}}
hΘ表示MLP。 与GCN层类似,可使用MessagePassing
类来实现它,聚合方式将使用"max"
:
import torch
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing
class EdgeConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super().__init__(aggr='max') # "Max" aggregation.
self.mlp = Seq(Linear(2 * in_channels, out_channels),
ReLU(),
Linear(out_channels, out_channels))
def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]
return self.propagate(edge_index, x=x)
def message(self, x_i, x_j):
# x_i has shape [E, in_channels]
# x_j has shape [E, in_channels]
tmp = torch.cat([x_i, x_j - x_i], dim=1) # tmp has shape [E, 2 * in_channels]
return self.mlp(tmp)
在message()
函数内部,self.mlp
用于变换目标节点的特征x_i
和每条边
(
j
,
i
)
∈
E
(j,i) \in \mathcal{E}
(j,i)∈E的相对源节点特征 x_j - x_i
。边卷积实际为是动态卷积,对GNN的每一层都在特征空间使用knn最近邻来重新计算图结构。
参考文献
[1] Pytorch-geometric官方文档-Creating Message Passing Networks
[2] https://blog.csdn.net/morgan777/article/details/121183287
[3] https://zhuanlan.zhihu.com/p/130796040
[4] https://blog.csdn.net/weixin_39925939/article/details/121360884