本文参考自datawhale2021.6学习:图神经网络
本章目录
1 消息传递范式
消息传递范式是一种聚合邻接节点信息来更新中心节点信息的范式,它将卷积算子推广到了不规则数据领域,实现了图与神经网络的连接
消息传递神经网络(MPNN)是一种框架,其前向传递有两个阶段:消息传递阶段(Message Passing)、读出阶段(Readout),这里先介绍消息传递阶段
1.1 消息传递的三个函数
- 三个函数分为:
- 各边要传递的消息的创建 ϕ \phi ϕ、消息聚合 □ \square □ 、节点表征的更新 γ \gamma γ 三个步骤
- 对三个函数的要求:
- 要求上述三个函数均可微
- 且消息聚合具有排列不变性(函数输出结果与输入参数的排列无关,即对节点的排列不敏感)。
- 具有排列不变性的函数有和函数、均值函数和最大值函数。
- 消息传递的数学描述:
- 用
x
i
(
k
−
1
)
∈
R
F
\mathbf{x}^{(k-1)}_i\in\mathbb{R}^F
xi(k−1)∈RF表示
(
k
−
1
)
(k-1)
(k−1)层中节点
i
i
i的节点属性,
e
j
,
i
∈
R
D
\mathbf{e}_{j,i} \in \mathbb{R}^D
ej,i∈RD 表示从节点
j
j
j到节点
i
i
i的边的属性,消息传递可以描述为:
x i ( k ) = γ ( k ) ( x i ( k − 1 ) , □ j ∈ N ( i ) ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) ) \mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right) xi(k)=γ(k)(xi(k−1),□j∈N(i)ϕ(k)(xi(k−1),xj(k−1),ej,i))
- 用
x
i
(
k
−
1
)
∈
R
F
\mathbf{x}^{(k-1)}_i\in\mathbb{R}^F
xi(k−1)∈RF表示
(
k
−
1
)
(k-1)
(k−1)层中节点
i
i
i的节点属性,
e
j
,
i
∈
R
D
\mathbf{e}_{j,i} \in \mathbb{R}^D
ej,i∈RD 表示从节点
j
j
j到节点
i
i
i的边的属性,消息传递可以描述为:
1.2 节点嵌入与节点表征
- 节点嵌入(Node Embedding):神经网络生成节点表征的操作,或节点表征也称节点嵌入
- 这里节点嵌入仅指代前者
- 好的节点表征可以衡量节点间的相似性,需要通过图神经网络训练得到
2 PyG的MessagePassing基类
2.1 属性
class MessagePassing(torch.nn.Module):
def __init__(self, aggr: Optional[str] = "add", flow: str = "source_to_target", node_dim: int = -2):
super(MessagePassing, self).__init__()
# 此处省略n行代码
self.aggr = aggr
assert self.aggr in ['add', 'mean', 'max',None]
self.flow = flow
assert self.flow in ['source_to_target', 'target_to_source']
self.node_dim = node_dim
self.fuse = self.inspector.implements('message_and_aggregate')
# 此处省略n行代码
- aggr:定义要使用的聚合方案,默认add
- flow:定义消息传递的流向,从而确定给某节点传递消息的边的集合,默认s→t
- 用 i i i 表示目标节点, j j j 表示邻接节点
- flow=‘source_to_target’ :target表入,即传递信息的边的集合为 ( j , i ) ∈ E (j,i)\in\mathcal{E} (j,i)∈E
- flow=‘target_to_source’ :target表出,即传递信息的边的集合为 ( i , j ) ∈ E (i,j)\in\mathcal{E} (i,j)∈E
- node_dim:定义scatter沿着哪个轴线传播,默认-2
- fuse:检查是否实现了message_and_aggregate()方法,不需要自己定义
2.2 方法
class MessagePassing(torch.nn.Module):
# 此处省略n行代码
self.fuse = self.inspector.implements('message_and_aggregate')
self.node_dim = node_dim
def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
# 此处省略n行代码
# 检查edge_index是否SparseTensor类型
# 检查是否实现了message_and_aggregate()方法,是就执行该方法,再执行update方法
if (isinstance(edge_index, SparseTensor) and self.fuse and not self.__explain__):
coll_dict = self.__collect__(self.__fused_user_args__, edge_index, size, kwargs)
# message_and_aggregate
msg_aggr_kwargs = self.inspector.distribute('message_and_aggregate', coll_dict)
out = self.message_and_aggregate(edge_index, **msg_aggr_kwargs)
# update
update_kwargs = self.inspector.distribute('update', coll_dict)
return self.update(out, **update_kwargs)
# 上述检查不通过,依次执行message(),aggregate(),update()方法
elif isinstance(edge_index, Tensor) or not self.fuse:
coll_dict = self.__collect__(self.__user_args__, edge_index, size, kwargs)
# message
msg_kwargs = self.inspector.distribute('message', coll_dict)
out = self.message(**msg_kwargs)
# aggregate
aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
out = self.aggregate(out, **aggr_kwargs)
# update
update_kwargs = self.inspector.distribute('update', coll_dict)
return self.update(out, **update_kwargs)
def message(self, x_j):
# 按需要覆写或不写
return x_j
def aggregate(self, inputs: Tensor, index: Tensor,
ptr: Optional[Tensor] = None,
dim_size: Optional[int] = None) -> Tensor:
# 按需要覆写或不写
if ptr is not None:
ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim())
return segment_csr(inputs, ptr, reduce=self.aggr)
else:
return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
def update(self, inputs):
# 按需要覆写或不写
return inputs
def message_and_aggregate(self, adj_t, x, norm):
# 按需要覆写或不写
return x
- propagate(edge_index, size=None, **kwargs)
- 调用以传递消息,在此方法中
message
、aggregate
、update
等方法被调用 - 若检测到
message_and_aggregate
和edge_index为SparseTensor,则即使message
和aggregate
存在也不调用,而是调用message_and_aggregate
- 可将节点属性拆分成中心节点和邻接节点,对拆分的数据有格式要求,必须为 [num_nodes, *]。拆分如属性 x i x_i xi 和邻接节点属性 x j x_j xj,度 d e g i 、 d e g j deg_i、deg_j degi、degj
- size=None默认邻接矩阵对称,若是非对称的邻接矩阵(如二部图)则要传递参数 size=(N,M)
- 调用以传递消息,在此方法中
- message (写入需要的参数…)
- 实现 ϕ \phi ϕ 函数,创建各边要传递的邻接节点消息
- 可以接收传递给
propagate
方法的任何参数,只要在其中进行定义。如def message(self,x_j)而非 def message(self,x_j=x_j)
- aggregate (inputs, …)
- 实现消息聚合
- 关于scatter(src,index,dim=-1,out,dim_size,reduce=‘sum’):按照dim的操作方向, 将src的元素加到index指示的位置去。参考torch_scatter.scatter或torch_scatter.scatter 区别scatter_
propagate
调用时,传入给inputs的是message的输出
- message_and_aggregate (写入需要的参数…)
- 一些场景里 ϕ \phi ϕ 和聚合可以融合在一起操作,就可以在该方法里定义这两项操作,使程序运行更加高效
- update (inputs, …)
- 节点表征的更新,可以接收传递给
propagate
方法的任何参数 propagate
调用时inputs输入的是aggregate
的输出
- 节点表征的更新,可以接收传递给
3 GCNConv的实现
3.1 数学定义
x
i
k
=
∑
j
∈
N
(
i
)
∪
{
i
}
1
d
(
v
i
)
⋅
d
(
v
j
)
⋅
(
Θ
⋅
x
i
k
−
1
)
x_i^k = \sum_{j_\in \mathcal{N}(i) \cup \{i\}}\frac{1}{\sqrt{d(v_i)}\cdot \sqrt{d(v_j)}}\cdot \left( \Theta\cdot x_i^{k-1}\right)
xik=j∈N(i)∪{i}∑d(vi)⋅d(vj)1⋅(Θ⋅xik−1)
或
X
k
=
D
^
−
1
2
A
^
D
^
−
1
2
Θ
X
k
−
1
X^k = \hat D^{-\frac{1}{2}}\hat A\hat D^{-\frac{1}{2}}\Theta X^{k-1}
Xk=D^−21A^D^−21ΘXk−1
A
^
=
A
+
I
\hat A=A+I
A^=A+I 加入了自循环的邻接矩阵,
D
^
\hat D
D^ 是由
A
^
\hat A
A^ 计算的度矩阵
矩阵A行表出,列表入。左矩阵D对对应的出节点×,右矩阵D对对应的入节点×
3.2 代码实现
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
# 线性变换层 Θ
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# x 形状 [N, in_channels]
# edge_index 形状 [2, E]
# 添加自环的边
# edge_index形状为 [2,E+N]
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# 节点属性做线性变换
x = self.lin(x)
# Compute normalization.
row, col = edge_index # row从节点0开始一直顺序排 e.g.[0,0,0,1,1…]
deg = degree(col, x.size(0), dtype=x.dtype) # 计算度矩阵
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# 若要将edge_index改写为SparseTensor
# adjmat = SparseTensor(row=edge_index[0], col=edge_index[1], value=torch.ones(edge_index.shape[1]))
# 调用propagate传递信息
return self.propagate(edge_index, x=x, norm=norm, deg=deg.view((-1, 1))) # 若一个数据可以被拆分成属于中心节点的部分和属于邻接节点的部分,其形状必须是 [num_nodes, *],所以需要将deg的形状进行变换
# return self.propagate(edge_index, x=x, norm=norm)
# return self.propagate(adjmat, x=x, norm=norm, deg=deg.view((-1, 1)))
# 覆写消息构建函数 Φ
def message(self, x_j, norm, deg_i):
# x_j 是邻接节点矩阵,形状为 [E+N, out_channels]
# 这里flow = 'source_to_target',因此x_j行排序如row
# deg_i 是col排序的点的度
return norm.view(-1, 1) * x_j # 将每个邻接节点正则化,返回形状同 x_j
# 不需要覆写aggregate和update
# 这里未实现message_and_aggregate
# 也可以覆写aggregate,举个例子
def aggregate(self, inputs, index, ptr, dim_size):
return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)
# index是中心节点,根据flow在此其排序同col
# dim_size = 节点数
# 覆写函数时,传入的参数不要写 y=y 这种格式
# 调用网络
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='dataset', name='Cora')
data = dataset[0]
net = GCNConv(data.num_features, 64) # 类属性的定义
h_nodes = net(data.x, data.edge_index) # 调用forward,输入参数
print(h_nodes.shape)
- 关于稀疏矩阵:torch的稀疏矩阵或torch的稀疏矩阵
- degree用于计算节点出/入度:顺序即为节点序号[0,1,…,2707]
无注释代码
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
x = self.lin(x)
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
return self.propagate(edge_index, x=x, norm=norm, deg=deg.view((-1, 1)))
def message(self, x_j, norm):
return norm.view(-1, 1) * x_j
def aggregate(self, inputs, index, ptr, dim_size):
return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='dataset', name='Cora')
data = dataset[0]
net = GCNConv(data.num_features, 64) # 类属性的定义
h_nodes = net(data.x, data.edge_index) # 调用forward,输入参数
print(h_nodes.shape)
4 作业
- 请总结
MessagePassing
基类的运行流程。
- 首先创建每条边上要传递的邻接节点的信息
- 其次对中心节点接收到的消息进行聚合
- 最后更新节点表征
- 请复现一个一层的图神经网络的构造,总结通过继承
MessagePassing
基类来构造自己的图神经网络类的规范。
- GNN规范:
- 属性如神经网络层、继承
flow
、aggr
等属性 - 定义
forward
方法,传入节点矩阵x
与边edge_index
- 添加自环边
- 节点属性变换
- 建立度矩阵,计算正则公式
- (上述两步也可在message中完成)
- 调用
propagate
,传入参数edge_index
、方法message、aggregate、update
要用到的参数如norm
、x
、deg
等 - 返回最终
update
后的节点表征
- 覆写有关函数
message
传入邻接节点信息x_j
,正则化公式norm
- 属性如神经网络层、继承
- 复现一个一层图神经网络
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
x = self.lin(x)
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
return self.propagate(edge_index, x=x, norm=norm, deg=deg.view((-1, 1)))
def message(self, x_j, norm):
return norm.view(-1, 1) * x_j
def update(self,aggr_output):
return F.relu(aggr_output)