1. 提出MessagePassing的目的
MessagePassing是图神经网络(Graph Neural Networks,GNNs)的一个基础组件,它被设计用来处理图形数据的问题。在图形数据中,数据点(节点)之间的关系(边)是非常重要的信息。MessagePassing通过在节点之间传递和聚合信息,使得每个节点都能获取其邻居节点的信息,从而更好地理解图形的结构和特性。
具体来说,MessagePassing的工作方式是这样的:对于每个节点,它会收集其所有邻居节点的信息(这个过程称为消息传递),然后将这些信息聚合起来(这个过程称为消息聚合)。这样,每个节点都能获取到其邻居节点的信息,从而更好地理解图形的结构和特性。
在许多图形相关的任务中,如社交网络分析、分子结构预测、推荐系统等,MessagePassing都发挥了重要的作用。
2. MessagePassing基类解析
用户自定义算子的时候,需要继承MessagePassing基类并重写propagate函数、message函数和update函数。
在图神经网络中,propagate、message、aggregate和update函数是实现信息传递(Message Passing)机制的关键部分。
propagate函数
这是信息传递过程的主要驱动函数。它负责调用message、aggregate和update函数,并将结果传递给下一层。propagate函数通常会接收图的边索引(edge_index)和节点特征(node features)作为输入,然后通过message函数计算出每条边的消息,接着通过aggregate函数聚合这些消息,最后通过update函数更新每个节点的特征。
def propagate(self, edge_index, size=None, **kwargs):
message函数
这个函数负责计算每条边的消息。它通常会接收源节点和目标节点的特征作为输入,然后计算出一个消息。这个消息通常是源节点和目标节点特征的函数。
def message(self, x_j: Tensor) -> Tensor:
aggregate函数
这个函数负责聚合每个节点的所有消息。它通常会接收一个节点的所有邻居节点的消息作为输入,然后计算出一个聚合的消息。这个聚合的消息通常是所有邻居节点消息的函数。
def aggregate(self, inputs: Tensor, index: Tensor, dim_size: Optional[int] = None) -> Tensor:
update函数
这个函数负责更新每个节点的特征。它通常会接收一个节点的旧特征和该节点所有邻居的消息的聚合(通过aggregate函数实现)作为输入,然后计算出一个新的特征。
def update(self, inputs: Tensor) -> Tensor:
这四个函数一起实现了图神经网络的信息传递机制,使得每个节点都能获取其邻居节点的信息,从而更好地理解图形的结构和特性。
forward函数
MessagePassing是继承自torch.nn.Module的,所以推理时会调用forward函数,自然也要重写forward函数。
在forward函数中,通常会定义图神经网络层的输入和输出,以及调用propagate函数。
propagate函数是信息传递的主要驱动函数,它会调用message、aggregate和update函数,并将结果传递给下一层。
这个调用栈是典型的图神经网络的信息传递过程,但具体的实现可能会根据不同的图神经网络模型有所不同。
继承MessagePassing基类定义自定义GNN的层
例子1:GCNConv
Steps 1-3 are typically computed before message passing takes place. Steps 4-5 can be easily processed using the MessagePassing base class.
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super().__init__(aggr='add') # "Add" aggregation (Step 5).
self.lin = Linear(in_channels, out_channels, bias=False)
self.bias = Parameter(torch.empty(out_channels))
self.reset_parameters()
def reset_parameters(self):
self.lin.reset_parameters()
self.bias.data.zero_()
def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]
# Step 1: Add self-loops to the adjacency matrix.
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: Linearly transform node feature matrix.
x = self.lin(x)
# Step 3: Compute normalization.
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Step 4-5: Start propagating messages.
out = self.propagate(edge_index, x=x, norm=norm)
# Step 6: Apply a final bias vector.
out += self.bias
return out
def message(self, x_j, norm):
# x_j has shape [E, out_channels]
# Step 4: Normalize node features.
return norm.view(-1, 1) * x_j
conv = GCNConv(16, 32)
x = conv(x, edge_index)
例子2:EdgeConv
import torch
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing
class EdgeConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super().__init__(aggr='max') # "Max" aggregation.
self.mlp = Seq(Linear(2 * in_channels, out_channels),
ReLU(),
Linear(out_channels, out_channels))
def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]
return self.propagate(edge_index, x=x)
def message(self, x_i, x_j):
# x_i has shape [E, in_channels]
# x_j has shape [E, in_channels]
tmp = torch.cat([x_i, x_j - x_i], dim=1) # tmp has shape [E, 2 * in_channels]
return self.mlp(tmp)
from torch_geometric.nn import knn_graph
class DynamicEdgeConv(EdgeConv):
def __init__(self, in_channels, out_channels, k=6):
super().__init__(in_channels, out_channels)
self.k = k
def forward(self, x, batch=None):
edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow)
return super().forward(x, edge_index)
conv = DynamicEdgeConv(3, 128, k=6)
x = conv(x, batch)