1、MessagePassing
基类的运行流程
MessagePassing类封装了“消息传递”的运行流程,可以用于构造消息传递图神经网络。主要参数定义如下:
MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)
(对象初始化方法):
-
aggr
:定义要使用的聚合方案("add"、"mean "或 "max"); -
flow
:定义消息传递的流向("source_to_target "或 "target_to_source"); -
node_dim
:定义沿着哪个维度传播,默认值为-2
,
对于简单的图卷积神经网络,数学定义为:
构造过程可以分为5步:
(1)向邻接矩阵添加自循环边。
(2)对节点表征做线性变换。
(3)计算归一化系统。
(4)归一化邻接节点的表征。
(5)聚合邻接节点表征。
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
import torch
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add') # "Add" aggregation (Step 5).
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# Step 1 向邻接矩阵添加自循环边
denoted by fill_value.
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: 对节点表征做线性变换.
x = self.lin(x)
# Step 3: 计算归一化系统
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Step 5: 聚合邻接节点表征.
return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
# x_j has shape [E, out_channels]
# Step 4: 归一化邻接节点的表征.
return norm.view(-1, 1) * x_j