一、消息传递范式介绍
消息传递范式是一种聚合邻接节点信息来更新中心节点信息的范式,它将卷积算子推广到不规则数据领域,实现了图与神经网络的连接。此范式包含三个步骤:(1)邻接节点信息变换;(2)邻接节点信息聚合到中心节点;(3)聚合信息变换。
消息传递图神经网络可以描述为:
x
i
(
k
)
=
γ
(
k
)
(
x
i
(
k
−
1
)
,
□
j
∈
N
(
i
)
ϕ
(
k
)
(
x
i
(
k
−
1
)
,
x
j
(
k
−
1
)
,
e
j
,
i
)
)
,
\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right),
xi(k)=γ(k)(xi(k−1),□j∈N(i)ϕ(k)(xi(k−1),xj(k−1),ej,i)),
x
i
(
k
−
1
)
∈
R
F
\mathbf{x}^{(k-1)}_i\in\mathbb{R}^F
xi(k−1)∈RF表示(k-1)层中节点i的节点特征,
e
j
,
i
∈
R
D
\mathbf{e}_{j,i} \in \mathbb{R}^D
ej,i∈RD表示从节点j到节点i的边的特征,
□
\square
□表示可微分的、具有排列不变形的函数,具有排列不变形的函数有和函数、均值函数和最大值函数。
γ
\gamma
γ和
ϕ
\phi
ϕ表示可微分的函数。
二、Pytorch Geometric中的MessagePassing基类
Pytorch Geometric提供了MessagePassing类,实现了消息传播的自动处理,继承该基类可以方便地构造消息传递图神经网络,我们只需要定义函数 ϕ \phi ϕ(即message函数)和函数 γ \gamma γ(即update函数),以及消息聚合方案(aggr=“add”、aggr="mean"或aggr=“max”)。
-
MessagePassing(aggr=“add”, flow=“source_to_target”, node_dim=-2):
aggr: 定义要使用的聚合方案(“add”、“mean"或"max”)
flow: 定义消息传递的流向(“source_to_target"或"target_to_source”)
node_dim: 定义沿着哪个轴线传播 -
MessagePassing.propagate(edge_index, size=None, **kwargs):
开始传播消息的起始调用。它以edge_index(边的端点的索引)和flow(消息的流向)以及一些额外的数据为参数。
size=(N,M)设置对称邻接矩阵的形状。 -
MessagePassing.message(…)接受最初传递给propagate函数的所有参数。
-
MessagePassing.aggregate(…)将从源节点传递过来的消息聚合在目标节点上,一般可选的聚合方式有sum,mean和max。
-
MessagePassing.message_and_aggregate(…)融合了邻接节点信息变换和邻接节点信息聚合。
-
MessagePassing.update(aggr_out, …)为每个节点更新节点表征,即实现 γ \gamma γ函数。该函数以聚合函数的输出为第一参数,并接收所有传递给propagate函数的参数。
三、继承MessagePassing类的GCNConv
GCNConv的数学定义为:
x
i
(
k
)
=
∑
j
∈
N
(
i
)
∪
{
i
}
1
deg
(
i
)
⋅
deg
(
j
)
⋅
(
Θ
⋅
x
j
(
k
−
1
)
)
,
\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right),
xi(k)=j∈N(i)∪{i}∑deg(i)⋅deg(j)1⋅(Θ⋅xj(k−1)),
其中相邻节点的特征通过权重矩阵
Θ
\mathbf{\Theta}
Θ进行转换,然后按端点的度进行归一化处理,最后进行加总。这个公式可以分为以下几个步骤:
- 向邻接矩阵添加自环边。
- 线性转换节点特征矩阵。
- 计算归一化系数。
- 归一化j中的节点特征。
- 将相邻节点特征相加。
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
# "Add" aggregation (Step 5).
# flow='source_to_target' 表示消息从源节点传播到目标节点
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]
# Step 1: Add self-loops to the adjacency matrix.
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: Linearly transform node feature matrix.
x = self.lin(x)
# Step 3: Compute normalization.
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Step 4-5: Start propagating messages.
return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
# x_j has shape [E, out_channels]
# Step 4: Normalize node features.
return norm.view(-1, 1) * x_j
# 初始化和调用
conv = GCNConv(16, 32)
x = conv(x, edge_index)
四、复写message函数
class GCNConv(MessagePassing):
def forward(self, x, edge_index):
# ....
return self.propagate(edge_index, x=x, norm=norm, d=d)
def message(self, x_j, norm, d_i):
# x_j has shape [E, out_channels]
return norm.view(-1, 1) * x_j * d_i
五、覆写aggregate函数
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
def forward(self, x, edge_index):
# ....
return self.propagate(edge_index, x=x, norm=norm, d=d)
def aggregate(self, inputs, index, ptr, dim_size):
print(self.aggr)
print("`aggregate` is called")
return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)
六、覆写aggregate函数
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
def forward(self, x, edge_index):
# ....
return self.propagate(edge_index, x=x, norm=norm, d=d)
def aggregate(self, inputs, index, ptr, dim_size):
print(self.aggr)
print("`aggregate` is called")
return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)
七、覆写message_and_aggregate函数
from torch_sparse import SparseTensor
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
def forward(self, x, edge_index):
# ....
adjmat = SparseTensor(row=edge_index[0], col=edge_index[1], value=torch.ones(edge_index.shape[1]))
# 此处传的不再是edge_idex,而是SparseTensor类型的Adjancency Matrix
return self.propagate(adjmat, x=x, norm=norm, d=d)
def message(self, x_j, norm, d_i):
# x_j has shape [E, out_channels]
return norm.view(-1, 1) * x_j * d_i # 这里不管正确性
def aggregate(self, inputs, index, ptr, dim_size):
print(self.aggr)
print("`aggregate` is called")
return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)
def message_and_aggregate(self, adj_t, x, norm):
print('`message_and_aggregate` is called')
八、覆写update函数
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
def update(self, inputs: Tensor) -> Tensor:
return inputs