图神经网络的消息传递范式
图神经网络的消息传播范式包含以下3个基本步骤:1)邻接节点信息变换;2)邻接节点信息聚合到中心节点;3)聚合信息变换。如下图所示:
在整个消息的传递过程中,可以将图神经网络的传播范式分为两个部分:第一部分是消息的表达——一个节点的消息与其(一跳)邻居节点有关,并且与聚合消息的方式有关;第二部分是消息的动态传递,即如何从一跳扩展到多跳的问题。重复第一部分的思想并将其串联起来,就完成了整个图的消息传递过程。
PyG中的MessagePassing类
图神经网络的消息传播可以通过继承MessagePassing类来实现,它能够自动完成消息的传播。理解其中与消息传播范式相对应的步骤即可:message、aggregate、update三个函数模块。
class MessagePassing(torch.nn.Module):
    r"""Base class for message-passing GNN layers (excerpt from PyG).

    Implements the ``propagate`` -> ``message`` -> ``aggregate`` ->
    ``update`` scheme.  Subclasses override :meth:`message`,
    :meth:`aggregate` and/or :meth:`update` to define a concrete layer.

    NOTE(review): this excerpt omits several helpers referenced below
    (``__check_input__``, ``__collect__``, ``update``,
    ``message_and_aggregate``, ``special_args``, ``Inspector``); they are
    defined in the full torch_geometric source.
    """

    def __init__(self, aggr: Optional[str] = "add",
                 flow: str = "source_to_target", node_dim: int = -2):
        super(MessagePassing, self).__init__()

        # Aggregation scheme used by `aggregate` (None means the subclass
        # handles aggregation itself, e.g. via `message_and_aggregate`).
        self.aggr = aggr
        assert self.aggr in ['add', 'mean', 'max', None]

        # Direction of message flow along each edge.
        self.flow = flow
        assert self.flow in ['source_to_target', 'target_to_source']

        # Axis along which nodes are indexed in feature tensors.
        self.node_dim = node_dim

        # Inspect the signatures of the user-overridable hooks to learn
        # which keyword arguments must be collected for each stage.
        self.inspector = Inspector(self)
        self.inspector.inspect(self.message)
        self.inspector.inspect(self.aggregate, pop_first=True)
        self.inspector.inspect(self.message_and_aggregate, pop_first=True)
        self.inspector.inspect(self.update, pop_first=True)

        self.__user_args__ = self.inspector.keys(
            ['message', 'aggregate', 'update']).difference(self.special_args)
        self.__fused_user_args__ = self.inspector.keys(
            ['message_and_aggregate', 'update']).difference(self.special_args)

        # Support for "fused" message passing.
        self.fuse = self.inspector.implements('message_and_aggregate')

        # Support for GNNExplainer.
        self.__explain__ = False
        self.__edge_mask__ = None

    def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
        r"""The initial call to start propagating messages.

        Args:
            edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor` or a
                :obj:`torch_sparse.SparseTensor` that defines the underlying
                graph connectivity/message passing flow.
                :obj:`edge_index` holds the indices of a general (sparse)
                assignment matrix of shape :obj:`[N, M]`.
                If :obj:`edge_index` is of type :obj:`torch.LongTensor`, its
                shape must be defined as :obj:`[2, num_messages]`, where
                messages from nodes in :obj:`edge_index[0]` are sent to
                nodes in :obj:`edge_index[1]`
                (in case :obj:`flow="source_to_target"`).
                If :obj:`edge_index` is of type
                :obj:`torch_sparse.SparseTensor`, its sparse indices
                :obj:`(row, col)` should relate to :obj:`row = edge_index[1]`
                and :obj:`col = edge_index[0]`.
                The major difference between both formats is that we need to
                input the *transposed* sparse adjacency matrix into
                :func:`propagate`.
            size (tuple, optional): The size :obj:`(N, M)` of the assignment
                matrix in case :obj:`edge_index` is a :obj:`LongTensor`.
                If set to :obj:`None`, the size will be automatically inferred
                and assumed to be quadratic.
                This argument is ignored in case :obj:`edge_index` is a
                :obj:`torch_sparse.SparseTensor`. (default: :obj:`None`)
            **kwargs: Any additional data which is needed to construct and
                aggregate messages, and to update node embeddings.
        """
        size = self.__check_input__(edge_index, size)

        # Run "fused" message and aggregation (if applicable).
        if (isinstance(edge_index, SparseTensor) and self.fuse
                and not self.__explain__):
            coll_dict = self.__collect__(self.__fused_user_args__, edge_index,
                                         size, kwargs)

            msg_aggr_kwargs = self.inspector.distribute(
                'message_and_aggregate', coll_dict)
            out = self.message_and_aggregate(edge_index, **msg_aggr_kwargs)

            update_kwargs = self.inspector.distribute('update', coll_dict)
            return self.update(out, **update_kwargs)

        # Otherwise, run both functions in separation.
        elif isinstance(edge_index, Tensor) or not self.fuse:
            coll_dict = self.__collect__(self.__user_args__, edge_index, size,
                                         kwargs)

            msg_kwargs = self.inspector.distribute('message', coll_dict)
            out = self.message(**msg_kwargs)

            # For `GNNExplainer`, we require a separate message and aggregate
            # procedure since this allows us to inject the `edge_mask` into the
            # message passing computation scheme.
            if self.__explain__:
                edge_mask = self.__edge_mask__.sigmoid()
                # Some ops add self-loops to `edge_index`. We need to do the
                # same for `edge_mask` (but do not train those).
                if out.size(self.node_dim) != edge_mask.size(0):
                    loop = edge_mask.new_ones(size[0])
                    edge_mask = torch.cat([edge_mask, loop], dim=0)
                assert out.size(self.node_dim) == edge_mask.size(0)
                out = out * edge_mask.view([-1] + [1] * (out.dim() - 1))

            aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
            out = self.aggregate(out, **aggr_kwargs)

            update_kwargs = self.inspector.distribute('update', coll_dict)
            return self.update(out, **update_kwargs)

    def message(self, x_j: Tensor) -> Tensor:
        r"""Constructs messages from node :math:`j` to node :math:`i`
        in analogy to :math:`\phi_{\mathbf{\Theta}}` for each edge in
        :obj:`edge_index`.
        This function can take any argument as input which was initially
        passed to :meth:`propagate`.
        Furthermore, tensors passed to :meth:`propagate` can be mapped to the
        respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or
        :obj:`_j` to the variable name, *e.g.* :obj:`x_i` and :obj:`x_j`.
        """
        # Default: the message is simply the source node's feature.
        return x_j

    def aggregate(self, inputs: Tensor, index: Tensor,
                  ptr: Optional[Tensor] = None,
                  dim_size: Optional[int] = None) -> Tensor:
        r"""Aggregates messages from neighbors as
        :math:`\square_{j \in \mathcal{N}(i)}`.
        Takes in the output of message computation as first argument and any
        argument which was initially passed to :meth:`propagate`.
        By default, this function will delegate its call to scatter functions
        that support "add", "mean" and "max" operations as specified in
        :meth:`__init__` by the :obj:`aggr` argument.
        """
        if ptr is not None:
            # CSR pointer available: use the faster segment reduction.
            ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim())
            return segment_csr(inputs, ptr, reduce=self.aggr)
        else:
            # Fall back to scatter-based reduction over `index`.
            return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
                           reduce=self.aggr)
作业
本作业演示使用MessagePassing基类构造自己的图神经网络类的规范。这里选择AGNNConv作为示例图神经网络类。AGNNConv的消息传播分为以下两步。
第一步:

$$\mathbf{X}^{\prime} = \mathbf{P} \mathbf{X}$$
第二步:计算传播矩阵(propagation matrix)$\mathbf{P}$:

$$P_{i,j} = \frac{\exp( \beta \cdot \cos(\mathbf{x}_i, \mathbf{x}_j))}{\sum_{k \in \mathcal{N}(i)\cup \{ i \}} \exp( \beta \cdot \cos(\mathbf{x}_i, \mathbf{x}_k))}$$
继承MessagePassing类部分:
class AGNNConv(MessagePassing):
    r"""Graph attentional propagation layer (AGNN), computing
    :math:`\mathbf{X}^{\prime} = \mathbf{P}\mathbf{X}` where the
    propagation matrix :math:`\mathbf{P}` is an attention over
    cosine similarities of neighboring node features.

    Args:
        requires_grad (bool): If :obj:`True`, the scaling factor
            :math:`\beta` is a learnable parameter; otherwise it is
            fixed to :obj:`1`.
        add_self_loops (bool): If :obj:`True`, self-loops are added to
            the graph in :meth:`forward`.
        **kwargs: Forwarded to :class:`MessagePassing`
            (aggregation defaults to "add").
    """
    def __init__(self, requires_grad: bool = True,
                 add_self_loops: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(AGNNConv, self).__init__(**kwargs)

        self.requires_grad = requires_grad
        self.add_self_loops = add_self_loops

        # Fix for the excerpt: `self.beta` is read in `message` but was
        # never initialized here.  Make it trainable (initialized to 1)
        # when `requires_grad` is set, otherwise a fixed buffer of ones.
        if requires_grad:
            self.beta = torch.nn.Parameter(torch.ones(1))
        else:
            self.register_buffer('beta', torch.ones(1))
消息整合部分:
def message(self, x_j, x_norm_i, x_norm_j,
            index: Tensor, ptr: OptTensor,
            size_i: Optional[int]) -> Tensor:
    """Compute the attention-weighted message for each edge.

    Args:
        x_j: Source-node features, one row per edge.
        x_norm_i / x_norm_j: L2-normalized features of the target /
            source node of each edge, so their dot product is the
            cosine similarity cos(x_i, x_j).
        index: Target-node index of each edge (softmax groups).
        ptr: Optional CSR pointer as an alternative grouping.
        size_i: Number of target nodes.

    Returns:
        x_j scaled by its attention coefficient P_{i,j}.
    """
    # beta-scaled cosine similarity between the two endpoints.
    alpha = self.beta * (x_norm_i * x_norm_j).sum(dim=-1)
    # Normalize per target node, so each row of P sums to 1.
    alpha = softmax(alpha, index, ptr, size_i)
    return x_j * alpha.view(-1, 1)
消息传播更新部分:
def forward(self, x, edge_index):
    """Run one AGNN propagation step: X' = P X.

    Args:
        x: Node feature matrix.
        edge_index: Graph connectivity as a LongTensor of shape
            [2, num_edges] or a SparseTensor.

    Returns:
        The propagated node features.
    """
    if self.add_self_loops:
        if isinstance(edge_index, Tensor):
            # Remove existing self-loops first so every node ends up
            # with exactly one after `add_self_loops`.
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index,
                                           num_nodes=x.size(self.node_dim))
        elif isinstance(edge_index, SparseTensor):
            edge_index = set_diag(edge_index)

    # L2-normalize features so the dot product computed in `message`
    # equals the cosine similarity of the original features.
    x_norm = F.normalize(x, p=2., dim=-1)
    return self.propagate(edge_index, x=x, x_norm=x_norm, size=None)
参考资料
[1] https://pytorch-geometric.readthedocs.io/en/latest/index.html
[2] https://github.com/datawhalechina/team-learning-nlp