图神经网络学习2

1、消息传递原理

为节点生成节点表征(Node Representation)是图计算任务成功的关键,我们要利用神经网络来学习节点表征。消息传递范式是一种聚合邻接节点信息来更新中心节点信息的范式,它将卷积算子推广到了不规则数据领域,实现了图与神经网络的连接。消息传递范式因为简单、强大的特性,于是被人们广泛地使用。遵循消息传递范式的图神经网络被称为消息传递图神经网络。
具体来说就是:
1)首先从邻居获取信息:计算红色节点周围的四个邻居节点的消息总和。
2)对获得的信息加以利用:将获得的消息与(k-1)时刻红色节点本身的表示组合起来,计算得到k时刻的红色节点表示。总的来说也就是利用结点周边的特征来获取更有用的特征
消息传递图神经网络遵循上述的“聚合邻接节点信息来更新中心节点信息的过程”,来生成节点表征。用 x i ( k − 1 ) ∈ R F \mathbf{x}^{(k-1)}_i\in\mathbb{R}^F xi(k1)RF表示 ( k − 1 ) (k-1) (k1)层中节点 i i i的节点表征, e j , i ∈ R D \mathbf{e}_{j,i} \in \mathbb{R}^D ej,iRD 表示从节点 j j j到节点 i i i的边的属性,消息传递图神经网络可以描述为
x i ( k ) = γ ( k ) ( x i ( k − 1 ) , □ j ∈ N ( i )   ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) ) , \mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right), xi(k)=γ(k)(xi(k1),jN(i)ϕ(k)(xi(k1),xj(k1),ej,i)),
其中 □ \square 表示可微分的、具有排列不变性(函数输出结果与输入参数的排列无关)的函数。具有排列不变性的函数有,sum()函数、mean()函数和max()函数。 γ \gamma γ ϕ \phi ϕ表示可微分的函数,如MLPs(多层感知器)。此处内容来源于CREATING MESSAGE PASSING NETWORKS

2、MessagePassing基类的运行流程。

1)定义message方法,为各条边创建要传递给节点 i i i的消息,即实现 ϕ \phi ϕ函数。
2)定义aggregation 函数,将从源节点传递过来的消息聚合在目标节点上,一般可选的聚合方式有sum, meanmax
3)定义update函数,为每个节点 i ∈ V i \in \mathcal{V} iV更新节点表征,即实现 γ \gamma γ函数。

3、复现一个一层的图神经网络的构造,总结通过继承MessagePassing基类来构造自己的图神经网络类的规范

3.1

我们以继承MessagePassing基类的GCNConv类为例,学习如何通过继承MessagePassing基类来实现一个简单的图神经网络。

GCNConv的数学定义为
x i ( k ) = ∑ j ∈ N ( i ) ∪ { i } 1 deg ⁡ ( i ) ⋅ deg ⁡ ( j ) ⋅ ( Θ ⋅ x j ( k − 1 ) ) , \mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right), xi(k)=jN(i){i}deg(i) deg(j) 1(Θxj(k1)),
其中,邻接节点的表征 x j ( k − 1 ) \mathbf{x}_j^{(k-1)} xj(k1)首先通过与权重矩阵 Θ \mathbf{\Theta} Θ相乘进行变换,然后按端点的度 deg ⁡ ( i ) , deg ⁡ ( j ) \deg(i), \deg(j) deg(i),deg(j)进行归一化处理,最后进行求和。这个公式可以分为以下几个步骤:

  1. 向邻接矩阵添加自环边。
  2. 对节点表征做线性转换。
  3. 计算归一化系数。
  4. 归一化邻接节点的节点表征。
  5. 将相邻节点表征相加("求和 "聚合)。

步骤1-3通常是在消息传递发生之前计算的。步骤4-5可以使用MessagePassing基类轻松处理。该层的全部实现如下所示。

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        # "Add" aggregation (Step 5).
        # flow='source_to_target' 表示消息从源节点传播到目标节点
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]
        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j

本文主要内容来自datawhale开源课程

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值