消息传递图神经网络
图计算任务成功的关键是为节点生成节点表征,而我们要利用神经网络来学习节点的表征。 消息传递范式是一种聚合邻接节点信息来更新中心节点信息的范式,它将卷积算子推广到了不规则数据领域,实现了图与神经网络的连接。
1 消息传递范式
如下图所示,展示了基于消息传递范式的聚合邻接节点信息来更新中心节点信息的过程:
图片来源于:Graph Neural Network • Introduction to Graph Neural Networks
- A节点是我们要更新的中心节点,对其更新需要聚合其邻接节点B、C、D的信息。图中黄色方框展示了在更新A节点之前B节点的更新过程:B的邻接节点A、C的信息经过变换后聚合到B节点,接着B节点信息与邻居节点聚合信息一起经过变换得到B节点的新的节点信息。同样的红色和绿色方框部分以同样的过程对C、D节点的进行了更新。实际上,这样的过程在所有的节点上都进行了一遍。我们可以递归的去理解这一过程。
- 这样邻接节点信息聚合到中心节点的过程就如蓝色框所示,A节点的信息得到了更新。而这样的过程会进行多次,之后产生的节点信息就作为节点表征。
对于上述”聚合邻接节点信息来更新中心节点信息的过程“,我们用用
x
i
(
k
−
1
)
∈
R
F
\mathbf{x}^{(k-1)}_i\in\mathbb{R}^F
xi(k−1)∈RF表示
(
k
−
1
)
(k-1)
(k−1)层中节点
i
i
i的节点表征,
e
j
,
i
∈
R
D
\mathbf{e}_{j,i} \in \mathbb{R}^D
ej,i∈RD 表示从节点
j
j
j到节点
i
i
i的边的属性,消息传递图神经网络可以描述为:
x
i
(
k
)
=
γ
(
k
)
(
x
i
(
k
−
1
)
,
□
j
∈
N
(
i
)
ϕ
(
k
)
(
x
i
(
k
−
1
)
,
x
j
(
k
−
1
)
,
e
j
,
i
)
)
,
\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right),
xi(k)=γ(k)(xi(k−1),□j∈N(i)ϕ(k)(xi(k−1),xj(k−1),ej,i)),
其中
□
\square
□表示可微分的、具有排列不变性(函数输出结果与输入参数的排列无关)的函数。具有排列不变性的函数有,和函数、均值函数和最大值函数。
γ
\gamma
γ和
ϕ
\phi
ϕ表示可微分的函数,如MLPs(多层感知器)。此处内容来源于CREATING MESSAGE PASSING NETWORKS。
注:
- 神经网络的生成节点表征的操作称为节点嵌入(Node Embedding)。
- 未经过训练的图神经网络生成的节点表征还不是好的节点表征,通过监督学习对图神经网络做很好的训练,可以生成好的节点表征,可用于衡量节点之间的相似性。
- 节点表征与节点属性的区分:节点属性
data.x
是节点的第0层节点表征,第 h h h层的节点表征经过一次的节点间信息传递产生第 h + 1 h+1 h+1层的节点表征。
2 MessagePassing基类
我们将初步分析PyG中的MessagePassing
基类,它封装了“消息传递”的运行流程,通过继承此基类我们可以方便地构造一个图神经网络。如果是构造一个简单的图神经网络类,我们只需定义message()
方法(
ϕ
\phi
ϕ)、update()
方法(
γ
\gamma
γ),以及使用的消息聚合方案(aggr="add"
、aggr="mean"
或aggr="max"
)。
MessagePassing源码请参考这里,下面我们介绍一下该基类的基本参数和方法:
-
MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)
(对象初始化方法):
~aggr
:定义要使用的聚合方案(“add”、"mean "或 “max”);
~flow
:定义消息传递的流向("source_to_target "或 “target_to_source”);
~node_dim
:定义沿着哪个维度传播,默认值为-2
,也就是节点表征张量(Tensor)的中节点的维度的那一个维度。节点表征张量x
形状为[num_nodes, num_features]
,其第0维度(也是第-2维度)是节点维度,其第1维度(也是第-1维度)是节点表征维度,所以我们可以设置node_dim=-2
。
~注:MessagePassing(……)
等同于MessagePassing.__init__(……)
-
MessagePassing.propagate(edge_index, size=None, **kwargs)
:
~开始传递消息的起始调用,在此方法中message
、update
等方法被调用。
~它以edge_index
(边的端点的索引)和flow
(消息的流向)以及一些额外的数据为参数。
~edge_index
(Tensor
或SparseTensor
) :定义底层图形连接/消息传递流。如果输入的是稀疏矩阵SparseTensor
(所有节点的邻接矩阵),则形状为[N, M]
,行表示头节点,列表是尾节点。如果输入是稠密矩阵Tensor
,那么形状为[2, num_messages]
,第0行为尾节点,第1行为头节点,头指向尾。
~size
:定义输入的edge_index
的大小,如果设置size=None
,则认为邻接矩阵是对称的。基于非对称的邻接矩阵进行消息传递(当图为二部图时),需要传递参数size=(N, M)
。
~propagate()
方法首先检查edge_index
是否为SparseTensor
类型以及是否子类实现了message_and_aggregate()
方法,如是就执行子类的message_and_aggregate
方法;否则依次执行子类的message(),aggregate(),update()
三个方法。 -
MessagePassing.message(...)
:
~为edge_index中的每条边构造从节点 j j j到节点 i i i的消息(类似于 ϕ ( k ) ϕ^(k) ϕ(k))。我们用 i i i表示“消息传递”中的中心节点,用 j j j表示“消息传递”中的邻接节点。
~首先确定要给节点 i i i传递消息的边的集合:1)如果flow="source_to_target"
,则是 ( j , i ) ∈ E (j,i) \in \mathcal{E} (j,i)∈E的边的集合;2)如果flow="target_to_source"
,则是 ( i , j ) ∈ E (i,j) \in \mathcal{E} (i,j)∈E的边的集合。
~MessagePassing.message(...)
方法可以接收传递给MessagePassing.propagate(edge_index, size=None, **kwargs)
方法的所有参数,我们在message()
方法的参数列表里定义要接收的参数,例如我们要接收x,y,z
参数,则我们应定义message(x,y,z)
方法。
~传递给propagate()
方法的参数,如果是节点的属性的话,可以被拆分成属于中心节点的部分和属于邻接节点的部分,只需在变量名后面加上_i
或_j
。例如,我们自己定义的meassage
方法包含参数x_i
,那么首先propagate()
方法将节点表征拆分成中心节点表征和邻接节点表征,接着propagate()
方法调用message
方法并传递中心节点表征给参数x_i
。而如果我们自己定义的meassage
方法包含参数x_j
,那么propagate()
方法会传递邻接节点表征给参数x_j
。 -
MessagePassing.aggregate(...)
:
~聚集来自邻居的消息, □ j ∈ N ( i ) \square_{j \in \mathcal{N}(i)} □j∈N(i)。
~将从源节点传递过来的消息聚合在目标节点上,一般可选的聚合方式有sum
,mean
和max
。
~将message()
计算的输出作为第一个参数以并接收所有传递给propagate()
方法的参数。 -
MessagePassing.message_and_aggregate(...)
:
~将message()
和aggregate()
的计算融合到单个函数中。
~在一些场景里,邻接节点信息变换和邻接节点信息聚合这两项操作可以融合在一起,那么我们可以在此方法里定义这两项操作,从而让程序运行更加高效。只适用于稀疏矩阵。 -
MessagePassing.update(aggr_out, ...)
:
~对于每个节点 i ∈ V i\in \mathcal{V} i∈V,以类似于 γ ( k ) \gamma^{(k)} γ(k)的方式更新节点嵌入。此方法以aggregate
方法的输出为第一个参数,并接收所有传递给propagate()
方法的参数。
3 MessagePassing
子类实例
我们以继承MessagePassing
基类的GCNConv
类为例,学习如何通过继承MessagePassing
基类来实现一个简单的图神经网络。GCNConv
的数学定义为:
X
′
=
D
^
−
1
/
2
A
^
D
^
−
1
/
2
X
Θ
,
\mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta},
X′=D^−1/2A^D^−1/2XΘ,
其中
A
^
=
A
+
I
\hat{A}=A+I
A^=A+I,表示插入自环的邻接矩阵。
D
^
i
i
=
∑
j
=
0
A
^
i
j
\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}
D^ii=∑j=0A^ij,是度矩阵为对角阵。
x
i
(
k
)
=
∑
j
∈
N
(
i
)
∪
{
i
}
1
deg
(
i
)
⋅
deg
(
j
)
⋅
(
Θ
⋅
x
j
(
k
−
1
)
)
,
\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right),
xi(k)=j∈N(i)∪{i}∑deg(i)⋅deg(j)1⋅(Θ⋅xj(k−1)),
其中,邻接节点的表征
x
j
(
k
−
1
)
\mathbf{x}_j^{(k-1)}
xj(k−1)首先通过与权重矩阵
Θ
\mathbf{\Theta}
Θ相乘进行变换,然后按端点的度
deg
(
i
)
,
deg
(
j
)
\deg(i), \deg(j)
deg(i),deg(j)进行归一化处理,最后进行求和。
这个公式可以分为以下几个步骤:
- 向邻接矩阵添加自环边(如果不添加对角线为0,不能利用自身信息)。
- 对节点表征做线性转换。
- 计算归一化系数。
- 归一化邻接节点的节点表征。
- 将相邻节点表征相加("求和 "聚合)。
步骤1-3通常是在消息传递发生之前计算的。步骤4-5可以使用MessagePassing
基类轻松处理。该层的全部实现如下所示。
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
# "Add" aggregation (Step 5).
# flow='source_to_target' 表示消息从源节点传播到目标节点
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]
# Step 1: Add self-loops to the adjacency matrix.
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: Linearly transform node feature matrix.
x = self.lin(x)
# Step 3: Compute normalization.
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Step 4-5: Start propagating messages.
return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
# x_j has shape [E, out_channels]
# Step 4: Normalize node features.
return norm.view(-1, 1) * x_j
GCNConv
继承了MessagePassing
并以"求和"作为领域节点信息聚合方式。从forward()
方法我们可以看到图运算的逻辑:
- 首先使用
add_self_loops
向边索引添加自循环边(步骤1); - 然后通过
torch.nn.Linear
对节点表征进行线性变换(步骤2); - 然后计算得到归一化系数
norm
形状为[num_edges,]
,公式中归一化系数是由每个节点的节点度得出的,在这里它被转换为每条边的节点度(步骤3); propagate()
方法也在forward
方法中被调用,propagate()
方法被调用后节点间的信息传递开始执行(步骤4、5)。
以上是一个仅包含一次消息传递过程的图神经网络,我们初始化并调用他:
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='dataset/Cora', name='Cora')
data = dataset[0]
# Data(edge_index=[2, 10556], test_mask=[2708], train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])
net = GCNConv(data.num_features, 64)
h_nodes = net(data.x, data.edge_index)
print(h_nodes.shape)
torch.Size([2708, 64])
而我们通过串联多个这样的简单图神经网络,我们就可以构造复杂的图神经网络模型。
4 MessagePassing
类中方法覆写
4.1 message
方法的覆写
前面我们介绍过,传递给propagate()
方法的参数,如果是节点的属性([2,num_messages]
)的话,可以被拆分成属于中心节点的部分和属于邻接节点的部分,只需在变量名后面加上_i
或_j
。而这一过程是通过MessagePassing
类中的__collect__()
方法实现的。
def __collect__(self, args, edge_index, size, kwargs):
i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)
out = {}
for arg in args:
if arg[-2:] not in ['_i', '_j']:
out[arg] = kwargs.get(arg, Parameter.empty)
else:
dim = 0 if arg[-2:] == '_j' else 1
data = kwargs.get(arg[:-2], Parameter.empty)
if isinstance(data, (tuple, list)):
assert len(data) == 2
if isinstance(data[1 - dim], Tensor):
self.__set_size__(size, 1 - dim, data[1 - dim])
data = data[dim]
if isinstance(data, Tensor):
self.__set_size__(size, dim, data)
data = self.__lift__(data, edge_index,
j if arg[-2:] == '_j' else i)
out[arg] = data
# 省略n行
return out
在这里我们想重写一下message
,使其包含邻接节点的部分x_j
和中心节点的度deg_i
:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
# "Add" aggregation (Step 5).
# flow='source_to_target' 表示消息从源节点传播到目标节点
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]
# Step 1: Add self-loops to the adjacency matrix.
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: Linearly transform node feature matrix.
x = self.lin(x)
# Step 3: Compute normalization.
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Step 4-5: Start propagating messages.
return self.propagate(edge_index, x=x, norm=norm, deg=deg.view((-1, 1)))
def message(self, x_j, norm, deg_i):
# x_j has shape [E, out_channels]
# deg_i has shape [E, 1]
# Step 4: Normalize node features.
return norm.view(-1, 1) * x_j * deg_i
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='dataset/Cora', name='Cora')
data = dataset[0]
net = GCNConv(data.num_features, 64)
h_nodes = net(data.x, data.edge_index)
print(h_nodes.shape)
torch.Size([2708, 64])
4.2 aggregate
方法的覆写
我们覆写aggregate
,当其被调用时打印"aggregate
is called",同时我们查看此时传递的参数,会发现在super(GCNConv, self).__init__(aggr='add')
中传递给aggr
参数的值被存储到了self.aggr
属性中。
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
# "Add" aggregation (Step 5).
# flow='source_to_target' 表示消息从源节点传播到目标节点
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]
# Step 1: Add self-loops to the adjacency matrix.
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: Linearly transform node feature matrix.
x = self.lin(x)
# Step 3: Compute normalization.
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Step 4-5: Start propagating messages.
return self.propagate(edge_index, x=x, norm=norm, deg=deg.view((-1, 1)))
def message(self, x_j, norm, deg_i):
# x_j has shape [E, out_channels]
# deg_i has shape [E, 1]
# Step 4: Normalize node features.
return norm.view(-1, 1) * x_j * deg_i
def aggregate(self, inputs, index, ptr, dim_size):
print('self.aggr:', self.aggr)
print("`aggregate` is called")
return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='dataset/Cora', name='Cora')
data = dataset[0]
net = GCNConv(data.num_features, 64)
h_nodes = net(data.x, data.edge_index)
print(h_nodes.shape)
self.aggr: add
aggregate
is called
torch.Size([2708, 64])
4.3 message_and_aggregate
方法的覆写
在某些情况下,“消息传递”与“消息聚合”可以融合在一起,此时我们可以覆写message_and_aggregate
方法,一块实现“消息传递”与“消息聚合”,这样能使程序的运行更加高效。
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch_sparse import SparseTensor
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
# "Add" aggregation (Step 5).
# flow='source_to_target' 表示消息从源节点传播到目标节点
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]
# Step 1: Add self-loops to the adjacency matrix.
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: Linearly transform node feature matrix.
x = self.lin(x)
# Step 3: Compute normalization.
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Step 4-5: Start propagating messages.
adjmat = SparseTensor(row=edge_index[0], col=edge_index[1], value=torch.ones(edge_index.shape[1]))
# 此处传的不再是edge_idex,而是SparseTensor类型的Adjancency Matrix
return self.propagate(adjmat, x=x, norm=norm, deg=deg.view((-1, 1)))
def message(self, x_j, norm, deg_i):
# x_j has shape [E, out_channels]
# deg_i has shape [E, 1]
# Step 4: Normalize node features.
return norm.view(-1, 1) * x_j * deg_i
def aggregate(self, inputs, index, ptr, dim_size):
print('self.aggr:', self.aggr)
print("`aggregate` is called")
return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)
def message_and_aggregate(self, adj_t, x, norm):
print('`message_and_aggregate` is called')
# 没有实现真实的消息传递与消息聚合的操作
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='dataset/Cora', name='Cora')
data = dataset[0]
net = GCNConv(data.num_features, 64)
h_nodes = net(data.x, data.edge_index)
# print(h_nodes.shape)
message_and_aggregate
is called
此处我们将节点的属性矩阵转变为稀疏矩阵SparseTensor
,这样才能满足propagate()
方法调用message_and_aggregate
方法的条件if (isinstance(edge_index, SparseTensor) and self.fuse and not self.__explain__):
。运行程序后我们可以看到,虽然我们同时覆写了message
方法和aggregate
方法,然而只有message_and_aggregate
方法被执行。
4.4 update
方法覆写
from torch_geometric.datasets import Planetoid
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch_sparse import SparseTensor
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
# "Add" aggregation (Step 5).
# flow='source_to_target' 表示消息从源节点传播到目标节点
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]
# Step 1: Add self-loops to the adjacency matrix.
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: Linearly transform node feature matrix.
x = self.lin(x)
# Step 3: Compute normalization.
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Step 4-5: Start propagating messages.
adjmat = SparseTensor(row=edge_index[0], col=edge_index[1], value=torch.ones(edge_index.shape[1]))
# 此处传的不再是edge_idex,而是SparseTensor类型的Adjancency Matrix
return self.propagate(adjmat, x=x, norm=norm, deg=deg.view((-1, 1)))
def message(self, x_j, norm, deg_i):
# x_j has shape [E, out_channels]
# deg_i has shape [E, 1]
# Step 4: Normalize node features.
return norm.view(-1, 1) * x_j * deg_i
def aggregate(self, inputs, index, ptr, dim_size):
print('self.aggr:', self.aggr)
print("`aggregate` is called")
return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)
def message_and_aggregate(self, adj_t, x, norm):
print('`message_and_aggregate` is called')
# 没有实现真实的消息传递与消息聚合的操作
def update(self, inputs, deg):
print(deg)
return inputs
dataset = Planetoid(root='dataset/Cora', name='Cora')
data = dataset[0]
net = GCNConv(data.num_features, 64)
h_nodes = net(data.x, data.edge_index)
# print(h_nodes.shape)
message_and_aggregate
is called
tensor([[4.],
[4.],
[6.],
…,
[2.],
[5.],
[5.]])
update
方法接收聚合的输出作为第一个参数,此外还可以接收传递给propagate
方法的任何参数。在上方的代码中,我们覆写的update
方法接收了聚合的输出作为第一个参数,此外接收了传递给propagate
的deg
参数。
总结
消息传递范式是一种聚合邻接节点信息来更新中心节点信息的范式,它将卷积算子推广到了不规则数据领域,实现了图与神经网络的连接。该范式包含这样三个步骤:(1)邻接节点信息变换、(2)邻接节点信息聚合到中心节点、(3)聚合信息变换。在PyG中,MessagePassing
基类是所有基于消息传递范式的图神经网络的基类,它大大地方便了我们对图神经网络的构建。
参考资料