1、消息传递原理
为节点生成节点表征(Node Representation)是图计算任务成功的关键,我们要利用神经网络来学习节点表征。消息传递范式是一种聚合邻接节点信息来更新中心节点信息的范式,它将卷积算子推广到了不规则数据领域,实现了图与神经网络的连接。消息传递范式因为简单、强大的特性,于是被人们广泛地使用。遵循消息传递范式的图神经网络被称为消息传递图神经网络。
具体来说就是:
1)首先从邻居获取信息:计算红色节点周围的四个邻居节点的消息总和。
2)对获得的信息加以利用:将获得的消息与(k-1)时刻红色节点本身的表示组合起来,计算得到k时刻的红色节点表示。总的来说也就是利用结点周边的特征来获取更有用的特征
消息传递图神经网络遵循上述的“聚合邻接节点信息来更新中心节点信息的过程”,来生成节点表征。用
x
i
(
k
−
1
)
∈
R
F
\mathbf{x}^{(k-1)}_i\in\mathbb{R}^F
xi(k−1)∈RF表示
(
k
−
1
)
(k-1)
(k−1)层中节点
i
i
i的节点表征,
e
j
,
i
∈
R
D
\mathbf{e}_{j,i} \in \mathbb{R}^D
ej,i∈RD 表示从节点
j
j
j到节点
i
i
i的边的属性,消息传递图神经网络可以描述为
x
i
(
k
)
=
γ
(
k
)
(
x
i
(
k
−
1
)
,
□
j
∈
N
(
i
)
ϕ
(
k
)
(
x
i
(
k
−
1
)
,
x
j
(
k
−
1
)
,
e
j
,
i
)
)
,
\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right),
xi(k)=γ(k)(xi(k−1),□j∈N(i)ϕ(k)(xi(k−1),xj(k−1),ej,i)),
其中
□
\square
□表示可微分的、具有排列不变性(函数输出结果与输入参数的排列无关)的函数。具有排列不变性的函数有,sum()
函数、mean()
函数和max()
函数。
γ
\gamma
γ和
ϕ
\phi
ϕ表示可微分的函数,如MLPs(多层感知器)。此处内容来源于CREATING MESSAGE PASSING NETWORKS。
2、MessagePassing
基类的运行流程。
1)定义message方法,为各条边创建要传递给节点
i
i
i的消息,即实现
ϕ
\phi
ϕ函数。
2)定义aggregation 函数,将从源节点传递过来的消息聚合在目标节点上,一般可选的聚合方式有sum
, mean
和max
。
3)定义update函数,为每个节点
i
∈
V
i \in \mathcal{V}
i∈V更新节点表征,即实现
γ
\gamma
γ函数。
3、复现一个一层的图神经网络的构造,总结通过继承MessagePassing
基类来构造自己的图神经网络类的规范
3.1
我们以继承MessagePassing
基类的GCNConv
类为例,学习如何通过继承MessagePassing
基类来实现一个简单的图神经网络。
GCNConv
的数学定义为
x
i
(
k
)
=
∑
j
∈
N
(
i
)
∪
{
i
}
1
deg
(
i
)
⋅
deg
(
j
)
⋅
(
Θ
⋅
x
j
(
k
−
1
)
)
,
\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right),
xi(k)=j∈N(i)∪{i}∑deg(i)⋅deg(j)1⋅(Θ⋅xj(k−1)),
其中,邻接节点的表征
x
j
(
k
−
1
)
\mathbf{x}_j^{(k-1)}
xj(k−1)首先通过与权重矩阵
Θ
\mathbf{\Theta}
Θ相乘进行变换,然后按端点的度
deg
(
i
)
,
deg
(
j
)
\deg(i), \deg(j)
deg(i),deg(j)进行归一化处理,最后进行求和。这个公式可以分为以下几个步骤:
- 向邻接矩阵添加自环边。
- 对节点表征做线性转换。
- 计算归一化系数。
- 归一化邻接节点的节点表征。
- 将相邻节点表征相加("求和 "聚合)。
步骤1-3通常是在消息传递发生之前计算的。步骤4-5可以使用MessagePassing
基类轻松处理。该层的全部实现如下所示。
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
# "Add" aggregation (Step 5).
# flow='source_to_target' 表示消息从源节点传播到目标节点
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]
# Step 1: Add self-loops to the adjacency matrix.
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: Linearly transform node feature matrix.
x = self.lin(x)
# Step 3: Compute normalization.
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Step 4-5: Start propagating messages.
return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
# x_j has shape [E, out_channels]
# Step 4: Normalize node features.
return norm.view(-1, 1) * x_j
本文主要内容来自datawhale开源课程