图神经网络GNN(二): 消息传递范式与简单GNN模型初识

1. 写在前面

这个系列整理的关于GNN的相关基础知识, 图深度学习是一个新兴的研究领域,将深度学习与图数据连接了起来,推动现实中图预测应用的发展。 之前一直想接触这一块内容,但总找不到能入门的好方法,而这次正好Datawhale有组队学习课,有大佬亲自带队学习入门,不犹豫,走起(感谢组织)。所以这个系列是参加GNN组队学习的相关知识沉淀, 希望能对GNN有一个好的入门吧 😉

这篇文章是消息传递范式, 上一篇文章里说为节点生成节点表征(Node Representation)是图计算任务成功的关键, 而图深度学习呢,是用神经网络来学习节点表征, 而学习节点表征,就难免需要考虑它的邻接节点的信息,这个过程中,又难免避开节点与节点之间的消息传递。

消息传递范式是一种聚合邻接节点信息来更新中心节点信息的范式。它将卷积算子推广到了不规则数据领域,实现了图与神经网络的连接。消息传递范式因为简单、强大的特性,于是被人们广泛地使用。遵循消息传递范式的图神经网络被称为消息传递图神经网络。 这篇文章首先介绍消息传递方式,然后通过代码的方式,构建一个GNN,来看看究竟节点与节点之间怎么个消息传递了,来完成自身更新的。这里面的重点是GNN的前向转播过程细节

内容大纲

  • 白话+数学语言之消息范式 – 基本介绍
  • 代码语言之消息传递方式 – MessagePassing基类分析

Ok, let’s go!

2. 白话+数学语言之消息范式 – 基本介绍

2.1 基于消息范式的中心节点更新过程

首先,看下基于消息传递范式的聚合邻接节点信息来更新中心节点信息的过程,这句话应该还是比较好理解的, 就是考虑它邻居节点来完成它自身节点的更新,基于下面这个图看下是如何完成自身节点更新的

在这里插入图片描述
宏观的信息更新过程如下:

  1. 图中黄色方框部分展示的是一次邻居节点信息传递到中心节点的过程:B节点的邻接节点(A,C)的信息经过变换聚合到B节点,接着B节点信息与邻居节点聚合信息一起经过变换得到B节点的新的节点信息。同时,分别如红色和绿色方框部分所示,遵循同样的过程,C、D节点的信息也被更新。实际上,同样的过程在所有节点上都进行了一遍,所有节点的信息都更新了一遍。
  2. 这样的“邻居节点信息传递到中心节点的过程”会进行多次。如图中蓝色方框部分所示,A节点的邻接节点(B,C,D)的已经发生过一次更新的节点信息,经过变换、聚合、再变换产生了A节点第二次更新的节点信息。多次更新后的节点信息就作为节点表征

2.2 消息传递图神经网络

消息传递图神经网络遵循上述的“聚合邻接节点信息来更新中心节点信息的过程”,来生成节点表征。

x i ( k − 1 ) ∈ R F \mathbf{x}^{(k-1)}_i\in\mathbb{R}^F xi(k1)RF表示 ( k − 1 ) (k-1) (k1)层中节点 i i i的节点表征, e j , i ∈ R D \mathbf{e}_{j,i} \in \mathbb{R}^D ej,iRD 表示从节点 j j j到节点 i i i的边的属性,消息传递图神经网络可以描述为
x i ( k ) = γ ( k ) ( x i ( k − 1 ) , □ j ∈ N ( i )   ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) ) , \mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right), xi(k)=γ(k)(xi(k1),jN(i)ϕ(k)(xi(k1),xj(k1),ej,i)),
其中 □ \square 表示可微分的、具有排列不变性(函数输出结果与输入参数的排列无关)的函数。具有排列不变性的函数有,和函数、均值函数和最大值函数。 γ \gamma γ ϕ \phi ϕ表示可微分的函数,如MLPs(多层感知器)。

这里的这个公式看上去挺吓人的, 但就是定义了消息传递图神经网络的节点更新方式而已,算是上面图中节点更新方式的数学表达。就类似于神经网络里面的前向传播公式, 这个公式先大体看懂逻辑即可,就是怎么个算法。

根据上面的那个消息传递图,目前给我的一个直观感觉, 把上面图里面的节点更新过程用数学公式翻译了一遍(简洁,易懂性)。第 k k k层中节点 i i i的更新,会和它前一次的状态,以及前一层中,与它相邻的节点的状态有关。

就比如上面图里面的B节点吧, B节点的新一次更新

  • 首先是与B相连的A,C节点信息聚合到B节点,此时,如果翻译成数学公式的话 ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) \phi^{(k)}(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}) ϕ(k)(xi(k1),xj(k1),ej,i), 而聚合到B节点之后, 不能直接拿来就用,还得再取其精华,去其槽粕,于是乎就得到了 □ j ∈ N ( i )   ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) jN(i)ϕ(k)(xi(k1),xj(k1),ej,i)
  • 接着B节点信息与邻居节点聚合信息一起经过变换得到B节点的新的节点信息, 翻译成数学语言: γ ( k ) ( x i ( k − 1 ) , □ j ∈ N ( i )   ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) ) \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right) γ(k)(xi(k1),jN(i)ϕ(k)(xi(k1),xj(k1),ej,i))
    γ 和 ϕ \gamma和\phi γϕ函数,是增加了非线性变换,让信息表达更加丰富而已。

这样, 这个公式就梳理清楚了吧。

这里文档中,还给出了几个注,算是本次组队学习的约定吧:

  • 注(1):神经网络的生成节点表征的操作称为节点嵌入(Node Embedding),节点表征也可以称为节点嵌入。为了统一表述,我们规定节点嵌入只代指神经网络生成节点表征的操作

    这里隐约的感觉,节点表征,仿佛又是用embedding的方式去表示,而如何学习到这个embedding,就是图神经网络做的事情了, 这个是从词嵌入那里类比过来的,至于是不是,还得往下学习。

  • 注(2):未经过训练的图神经网络生成的节点表征还不是好的节点表征,好的节点表征可用于衡量节点之间的相似性。通过监督学习对图神经网络做很好的训练,图神经网络才可以生成好的节点表征。 到这里就更像词嵌入的方式了。

    看到这里,首先先联想到了词的表征,这个是NLP里面最基础的知识,对于一个词,我们如何去表示它呢? 词嵌入的方式目前是比较不错的一种方式,也就是通过一个embedding向量表示一个词,这个embedding向量表示,白话来讲,就是形成了一个空间(向量的每一维想象成一个坐标),把这个词通过每一维的坐标值映射到空间中的点上去,这样,从空间的角度就能看出词与词之间的相似性啥的。

    那么,这个向量是怎么得到的呢? 比较早且经典的方式,就是W2V。这个是通过神经网络去进行训练,首先,随机初始化embedding参数,然后通过中心词预测上下文或者上下文预测中心词的方式,去训练这个神经网络,最终当收敛,之后,就能得到较好的embedding参数了,而这个就能表示每个单词。为什么呢? 因为训练的时候是考虑了词与词之间的相关性的,而这里基于的一个假设,就是当词与词之间同时出现,且接近,往往认为这两个词之间关系比较近。

    而回忆完了这个过程,再去看这里的节点表征,是不是非常非常像呢?如果把单词用图上的节点表示,如果两个词同时出现,就连条线转成图,那么节点表征的生成过程,是不是上面这个过程了? 只不过这里的神经网络是图神经网络而已。哈哈,原来知识都是想通的呀!

  • 注(3),节点表征与节点属性的区分:遵循被广泛使用的约定,此次组队学习我们也约定,节点属性data.x是节点的第0层节点表征,第 h h h层的节点表征经过一次的节点间信息传递产生第 h + 1 h+1 h+1层的节点表征。不过,节点属性不单指data.x,广义上它就指节点的属性,如节点的度等。

由于之前,学过一点NLP的皮毛,然后就可以把那边的知识迁移到这里来帮助理解上面的内容,目前感觉这个原理并不是很复杂,接下来,就看看具体的实现细节了,比如,图神经网络是怎么构建出来的,长啥样, 节点embedding又是如何通过图神经网络训练出来的?

这些就需要从GNN的角度去深入学习了,组队学习里面的设计是直接从代码层面展开的, 那么展开之前,就需要先了解下上面的这个消息范式传递过程,又是从代码的语言是如何表示的?

3. 代码语言之消息传递方式

3.1 MessagePassing基类初步分析

关于MessagePassing的详细介绍,看官方文档

Pytorch Geometric(PyG)提供了MessagePassing基类,它封装了“消息传递”的运行流程。通过继承MessagePassing基类,可以方便地构造消息传递图神经网络。

构造一个最简单的消息传递图神经网络类,我们只需定义message()方法 ϕ \phi ϕupdate()方法 γ \gamma γ,以及使用的消息聚合方案(aggr="add"aggr="mean"aggr="max"。这一切是在以下方法的帮助下完成的:

  • MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)(对象初始化方法):
    • aggr:定义要使用的聚合方案(“add”、"mean "或 “max”);
    • flow:定义消息传递的流向("source_to_target "或 “target_to_source”);
    • node_dim:定义沿着哪个维度传播,默认值为-2,也就是节点表征张量(Tensor)的哪一个维度是节点维度。节点表征张量x形状为[num_nodes, num_features],其第0维度(也是第-2维度)是节点维度,其第1维度(也是第-1维度)是节点表征维度,所以我们可以设置node_dim=-2
    • 注:MessagePassing(……)等同于MessagePassing.__init__(……)
  • MessagePassing.propagate(edge_index, size=None, **kwargs)
    • 开始传递消息的起始调用,在此方法中messageupdate等方法被调用。
    • 它以edge_index(边的端点的索引)和flow(消息的流向)以及一些额外的数据为参数。
    • 请注意,propagate()不仅限于基于形状为[N, N]的对称邻接矩阵进行“消息传递过程”。基于非对称的邻接矩阵进行消息传递(当图为二部图时),需要传递参数size=(N, M)
    • 如果设置size=None,则认为邻接矩阵是对称的。
  • MessagePassing.message(...)
    • 首先确定要给节点 i i i传递消息的边的集合:
      • 如果flow="source_to_target",则是 ( j , i ) ∈ E (j,i) \in \mathcal{E} (j,i)E的边的集合;
      • 如果flow="target_to_source",则是 ( i , j ) ∈ E (i,j) \in \mathcal{E} (i,j)E的边的集合。
    • 接着为各条边创建要传递给节点 i i i的消息,即实现 ϕ \phi ϕ函数。
    • MessagePassing.message(...)方法可以接收传递给MessagePassing.propagate(edge_index, size=None, **kwargs)方法的所有参数,我们在message()方法的参数列表里定义要接收的参数,例如我们要接收x,y,z参数,则我们应定义message(x,y,z)方法。
    • 传递给propagate()方法的参数,如果是节点的属性的话,可以被拆分成属于中心节点的部分和属于邻接节点的部分,只需在变量名后面加上_i_j。例如,我们自己定义的meassage方法包含参数x_i,那么首先propagate()方法将节点表征拆分成中心节点表征和邻接节点表征,接着propagate()方法调用message方法并传递中心节点表征给参数x_i。而如果我们自己定义的meassage方法包含参数x_j,那么propagate()方法会传递邻接节点表征给参数x_j
    • 我们用 i i i表示“消息传递”中的中心节点,用 j j j表示“消息传递”中的邻接节点。
  • MessagePassing.aggregate(...):aggregate(聚合)
    • 将从源节点传递过来的消息聚合在目标节点上,一般可选的聚合方式有sum, meanmax
  • MessagePassing.message_and_aggregate(...)
    • 在一些场景里,邻接节点信息变换和邻接节点信息聚合这两项操作可以融合在一起,那么我们可以在此方法里定义这两项操作,从而让程序运行更加高效。
  • MessagePassing.update(aggr_out, ...):
    • 为每个节点 i ∈ V i \in \mathcal{V} iV更新节点表征,即实现 γ \gamma γ函数。此方法以aggregate方法的输出为第一个参数,并接收所有传递给propagate()方法的参数。

看到上面这些,我是一脸懵逼,不过我感觉是正常的, 毕竟这些都是第一次接触, 第一次接触某个类,难免会有这种感觉,并且也肯定记不住。 不过, 大体的传递逻辑得知道,也就是节点在哪里初始化, 消息传递在哪里,消息聚合在哪里等,说白了,就是得和前面的数学公式相应的部分对应上,否则,我们是不知道怎么编写代码实现图神经网络的。 这里我先把数学公式摘过来,然后说说我对上面代码的体会。
x i ( k ) = γ ( k ) ( x i ( k − 1 ) , □ j ∈ N ( i )   ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) ) , \mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right), xi(k)=γ(k)(xi(k1),jN(i)ϕ(k)(xi(k1),xj(k1),ej,i)),
这个节点更新过程数学公式,估计经过上面的解释了然了,那么在代码中,我们应该如何实现这个过程呢? 这里就用到了PyG包的一个基类MessagePassing。

  • 首先它包含了一个构造函数,在这里,我们完成传递前的初始化工作,比如信息传递维度啊, 这里由于是节点与节点的传递,所以维度是0(-2), 传递流向啊,聚合方案等。
  • 其次,是包括了5个方法, 就对应了上面数学公式的实现过程了
    • MessagePassing.message(...): 这个实现的是 ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) ϕ(k)(xi(k1),xj(k1),ej,i), 汇聚相邻节点的信息到中心节点。
      所以在这里面,首先得定义节点 i i i传递消息的边的集合,这里可以从起点指向终点,也可以从终点指向起点。接着,要在里面实现 ϕ \phi ϕ函数的逻辑,也就是相邻节点的消息是如何传递到当前中心节点上去的, 注意,目前只是实现逻辑,并不是真正的调用。 至于上面解释的该函数后面的朦胧细节,等下面看了源码之后,自然就敞亮了。先梳理通逻辑。
    • MessagePassing.message_and_aggregate(...): 实现的是 □ j ∈ N ( i ) \square_{j \in \mathcal{N}(i)} jN(i),从源节点传递过来的消息聚合在目标节点上
    • MessagePassing.message_and_aggregate(...):这个就是把上面两个逻辑合在一块实现,比较容易理解。
    • MessagePassing.update(aggr_out, ...): 实现的是 γ ( k ) ( x i ( k − 1 ) , □ j ∈ N ( i )   ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) ) \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right) γ(k)(xi(k1),jN(i)ϕ(k)(xi(k1),xj(k1),ej,i)): 也就是自身的消息和邻近节点消息的合并计算逻辑,实现自我更新。
    • MessagePassing.propagate(edge_index, size=None, **kwargs): 这个我感觉应该最后再说,因为上面这些都只是消息传递的逻辑实现, 而这个才是真正的调用上面的方法,进行传递的过程, 就类似于神经网络里面的forward()函数。实现了消息传递过程,而这里面,是调用了上面的方法而已。

这样,整个逻辑就梳理清楚了, 我看代码的习惯就是先把整个逻辑梳理好,大致上知道每个函数是干啥的,不用过分先陷入到细节中,并且仅仅从上面的描述里面,我们也很难陷入细节,这个时候心态要好哈哈,不要刚开始看,就这块怎么一点也看不懂,然后又去官网看,加上英文,更是一脸懵逼,结果就得出了哦, 原来我不适合学GNN啊,然后"立即推,放弃学习,打开抖音,刷个小视频压压惊"。哈哈,这个思路就是我之前学习的展现。 但现在肯定不会了呀, 我们没有上帝的视角,哪能刚接触一个概念,就立马哦,懂了,后面应该是这样这样子…。所以,不懂才能显示出我们是正常人,这是和正常人一块愉快玩耍的前提,再稍微有点耐心,往下看看,说不定就开始悟了呢哈哈。

有点扯远了, 抓紧回来,梳理完了整个基类MessagePassing 之后, 我们大体上知道了每个函数到底在干嘛,如果我们写的时候,也大致上有点感觉,应该在哪块地方,实现一个什么样的逻辑等。 那么接下来,就是看一个例子了嘛,在这里,不就能把前面的都打通了? 所以有时候只需要坚持那么一点点而已啦 😉。

下面开始代码有些多, 戴好安全帽!!!

3.2 MessagePassing子类实例

这里主要是说,如何具体实现一个GCN层的前向传播逻辑(简单图神经网络),用人话讲,就是上面数学公式的具体细节了,或者说中心节点更新的代码细节。 具体的可参考Implementing the GCN Layer官方文档

我们以继承MessagePassing基类的GCNConv类为例,学习如何通过继承MessagePassing基类来实现一个简单的图神经网络。

GCNConv的数学定义为
x i ( k ) = ∑ j ∈ N ( i ) ∪ { i } 1 deg ⁡ ( i ) ⋅ deg ⁡ ( j ) ⋅ ( Θ ⋅ x j ( k − 1 ) ) , \mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right), xi(k)=jN(i){i}deg(i) deg(j) 1(Θxj(k1)),
其中,邻接节点的表征 x j ( k − 1 ) \mathbf{x}_j^{(k-1)} xj(k1)首先通过与权重矩阵 Θ \mathbf{\Theta} Θ相乘进行变换,然后按端点的度 deg ⁡ ( i ) , deg ⁡ ( j ) \deg(i), \deg(j) deg(i),deg(j)进行归一化处理,最后进行求和。这个公式可以分为以下几个步骤:

  1. 向邻接矩阵添加自环边。
  2. 对节点表征做线性转换。
  3. 计算归一化系数。
  4. 归一化邻接节点的节点表征。
  5. 将相邻节点表征相加("求和 "聚合)。

步骤1-3通常是在消息传递发生之前计算的。步骤4-5可以使用MessagePassing基类轻松处理。该层的全部实现如下所示。

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

from torch_geometric.datasets import Planetoid

class GCNConv(MessagePassing):

    def __init__(self, in_channels, out_channels):
        """
        初始化: 这里在基类的基础上定义了一个Linear层
        :param in_channels: 输入特征个数
        :param out_channels: 输出特征个数,Linear的神经单元数量
        """
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')

        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        """
        前向传播逻辑
        :param x: 节点表示矩阵, [N, in_channels], N个节点, in_channles个特恒
        :param edge_index: 边的信息,[2, E], 2行, E列,每一列表示一条边,索引指向(起始点索引->终止点索引)
        :return:
        """
        # 1. 增加邻接矩阵的自循环, 这个函数做的事情,就是每个节点i到i再增加一条边,把索引填到最终的edge_index中
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))  # _那个是权重,这里无权,所以不要这个,看下源码就明白了
        # print(edge_index.shape)
        
        # 2. 特征矩阵的线性变换
        x = self.lin(x)   # [N, out_channels]

        # 3. 计算归一化系数
        source, target = edge_index
        deg = degree(target, x.size(0), dtype=x.dtype)  # 这个函数在计算每个节点的度
        #print(deg)         # 每个节点的度  [4., 4., 6.,  ..., 2., 5., 5.]
        #print(deg.shape)   # 2708
        
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[source] * deg_inv_sqrt[target]   # [E]

        # 4-5 开始前向转播
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        """

        :param x_j: [E, out_channels]
        :param norm: 归一化的节点特征
        :return:
        """
        return norm.view(-1, 1) * x_j   # [E, out_channels], 这里是经历过广播

if __name__ == '__main__':
    dataset = Planetoid(root='../dataset/Cora', name='Cora')
    data = dataset[0]     # 图结构拿到

    net = GCNConv(data.num_features, 64)
    h_nodes = net(data.x, data.edge_index)   # x [2708, 1433], data.edge_index [2, 10556]
    print(h_nodes.shape)   # [2708, 64]

GCNConv继承了MessagePassing并以"求和"作为领域节点信息聚合方式。该层的所有逻辑都发生在其forward()方法中。在这里,我们首先使用torch_geometric.utils.add_self_loops()函数向我们的边索引添加自循环边(步骤1),以及通过调用torch.nn.Linear实例对节点表征进行线性变换(步骤2)。propagate()方法也在forward方法中被调用,propagate()方法被调用后节点间的信息传递开始执行。

归一化系数是由每个节点的节点度得出的,它被转换为每条边的节点度。结果被保存在形状为[num_edges,]的变量norm中(步骤3)。

message()方法中,我们需要通过norm对邻接节点表征x_j进行归一化处理。

通过以上内容的学习,我们便掌握了创建一个仅包含一次“消息传递过程”的图神经网络的方法, 而通过串联多个这样的简单图神经网络,我们就可以构造复杂的图神经网络模型

上面这个感觉是一个重点内容,也是学习后面的基础,所以必须要把这个梳理明白,仅仅看文档里面的描述是不够的,所以,下面我又自己梳理了一遍细节,所谓细节,就是争取弄懂每一行代码在干啥了

首先,这里用的数据集依然是上一篇文章中提到的那个数据集,具体描述在上一篇,这个数据集有2708篇文档,每篇文档用1433个特征来描述, 用10556条边,这是基本情况。

  1. 导入数据集

  2. 建立GCN网络,这个网络接收的参数是输入特征数量和输出特征数量,中间的映射关系就是Linear层, 这里的输出特征用了64,这个过程是Linear的1433到64特征的非线性变换
    在这里插入图片描述

  3. GCN网络的初始化
    这里是完成GCN的初始化操作,更准确的说是父类的初始化操作,这里面会定义聚合的操作,检查子类是否实现了message_and_aggregate()方法,并将检查结果赋值给fuse属性。这个会影响信息更新的前向传播过程。

    class MessagePassing(torch.nn.Module):
    			def __init__(self, aggr: Optional[str] = "add", flow: str = "source_to_target", node_dim: int = -2):
    		        super(MessagePassing, self).__init__()
    				# 此处省略n行代码
    		        # Support for "fused" message passing.
    		        self.fuse = self.inspector.implements('message_and_aggregate')
    				# 此处省略n行代码
    
  4. GCN网络的前向传播,这里是核心点,首先接收的参数是数据x和边的信息矩阵edge_index,这两个分别表示节点的特征和节点之间的连接情况,这个里面的计算是核心,分了5步进行。

    1. 创建自环边 add_self_loops
      这个函数干的一个事情就是节点到节点自身的索引也加到边的信息里面去,相当于在原来的基础上,加了自身到自身的边索引,拼到了edge_index,此时这个边的维度变成[2,10556+2708], 2708是节点个数,那么这个是维度是显然了吧。例如,从[[0,1,1,2],[1,0,2,1]]变成了[[0,1,1,2,0,1,2],[1,0,2,1,0,1,2]]

    2. 接下来,x自身的线性变换 linear, 这个不用解释,过了一个linear层

    3. 计算归一化系数,也就是各个节点的度

      1. 首先source, target拿到了edge_index的第一行和第二行数据, 也就是起点和终点索引,所以这个地方文档的row,col感觉符号标识的有点不好理解。每个都是13264长度的列表。看下图就明白了。
        在这里插入图片描述
      2. 计算节点度的函数在degree里面, 具体细节不用管,传进终点索引去,传进节点个数,自动就能算出每个节点的度数来,当然,传source,也就是起点进去是一样的。这个是2708的大小,每个位置表示对应节点的度数
      3. 归一化操作及下面的相乘这样就非常容易理解了,这个最后得到的norm是13624的列表, 每个元素代表了相应边的source节点的度与target节点度开根号的乘积。
    4. 节点更新的前向传播过程,这里才是消息传递的重头戏, 只可惜,一句话这里就完事了,但仅仅这一句话,怎么能看清楚具体过程? 所以,这里我开启了断点调试。

      1. 首先,这里会进入self.propagate(edge_index, x=x, norm=norm)函数,这个就是父类MessagePassing的核心实现函数,这里接收的是边的连接信息,节点的特征信息以及节点的度的归一化结果。 “消息传递过程”是从propagate方法被调用开始执行的。 step into。

      2. 进去之后,源码有些复杂, 看重要的两块逻辑:

        def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
            	# 此处省略n行代码
                # Run "fused" message and aggregation (if applicable).
                if (isinstance(edge_index, SparseTensor) and self.fuse and not self.__explain__):
                    coll_dict = self.__collect__(self.__fused_user_args__, edge_index, size, kwargs)
        
                    msg_aggr_kwargs = self.inspector.distribute('message_and_aggregate', coll_dict)
                    out = self.message_and_aggregate(edge_index, **msg_aggr_kwargs)
        
                    update_kwargs = self.inspector.distribute('update', coll_dict)
                    return self.update(out, **update_kwargs)
                # Otherwise, run both functions in separation.
                elif isinstance(edge_index, Tensor) or not self.fuse:
                    coll_dict = self.__collect__(self.__user_args__, edge_index, size, kwargs)
        
                    msg_kwargs = self.inspector.distribute('message', coll_dict)
                    out = self.message(**msg_kwargs)
            		# 此处省略n行代码
                    aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
                    out = self.aggregate(out, **aggr_kwargs)
        
                    update_kwargs = self.inspector.distribute('update', coll_dict)
                    return self.update(out, **update_kwargs)
        

        这里的**kwags是一个词典,主要包括x和norm。
        propagate()方法首先检查edge_index是否为SparseTensor类型以及是否子类实现了message_and_aggregate()方法,如是就执行子类的message_and_aggregate方法;否则依次执行子类的message(),aggregate(),update()三个方法。即这里的区别是我们到底想前两步是不是要放一块。这个例子中,是分别执行后面的三个函数。

        • 首先是msg_kwargs, 这个东西会得到一个字典, 包括x_j和norm,其中x_j[E, out_channels]的维度, 表示的是各个边的source节点的特征表示, 当然这里是自悟的, 找资料一看,还真是这样哈哈, 下面这个图比较详细的看清x_j
          在这里插入图片描述
          当然,这个图后来我仔细看了下源码, 这里说target节点的embedding并不是完全正确的, 这个得先通过指定的方向参数self.flow来判断

          • flow="source_to_target"时,节点edge_index[0]的信息将被传递到节点edge_index[1],此时我们汇总的应该是source信息,这时候,应该将上面source节点对应索引的embedding向量合并起来得到x_j
          • flow="target_to_source"时,节点edge_index[1]的信息将被传递到节点edge_index[0], 此时上面一行是target,下面一行是source,依然是将source对应的节点embedding合并得到x_j。不过,不影响我们先理解x_j本身,这个东西反正表示的是source节点处索引对应的embedding。


          这个地方如果不理解的话,后面聚合的时候就会懵逼,并且这里真的是源码里面有体现:

              def __collect__(self, args, edge_index, size, kwargs):
                  i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)  # j那边会发现始终都是source端
          
        • 再执行message函数, 这个step into之后,会发现跳进了我们自己实现的message函数,而这个返回了一个[E, out_channels]的矩阵,这个函数做的事情是度归一化的节点特征表示。

          def message(self, x_j, norm):
                  """
          
                  :param x_j: [E, out_channels]
                  :param norm: 归一化的节点特征
                  :return:
                  """
                  return norm.view(-1, 1) * x_j   # [E, out_channels], 这里是经历过广播
          
          

          norm是上面说的每一条边source节点的度与target节点度开根号的乘积, 是个[E,1]的, 而后面的x_j是target节点的embedding,是[E, out_channels]的,两者一乘,就实现了 1 deg ⁡ ( i ) ⋅ deg ⁡ ( j ) ⋅ ( Θ ⋅ x j ( k − 1 ) ) \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right) deg(i) deg(j) 1(Θxj(k1)), 而这,就代表着要传递的信息, 那么怎么传递呢?下面

        • 然后是aggr_kwargs,获取聚合需要的参数,这个东西也是个字典,'ptr’表示聚合策略,如果指定,就用指定的,如果没有,就用默认的sum. dim_size是维度,这里是节点数量

        • 执行aggregate函数,这个返回的是[2708,64]的矩阵。
          这个过程是这样, 对于图中每个节点,我们会从x_j中找到它自身的embedding以及指向它的邻居的embedding, 执行聚合操作(sum,mean, max, min)等, 得到融入邻居节点信息的embedding(节点表征)。公式的话是这样,注意求和符合下面的条件,节点 i i i以及节点 i i i的邻居,而这个邻居表示的是从邻居出发指向 i i i的那些节点。
          x i ( k ) = ∑ j ∈ N ( i ) ∪ { i } 1 deg ⁡ ( i ) ⋅ deg ⁡ ( j ) ⋅ ( Θ ⋅ x j ( k − 1 ) ) , \mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right), xi(k)=jN(i){i}deg(i) deg(j) 1(Θxj(k1)),
          图的话是这样:
          在这里插入图片描述
          这里依然是感觉它这个里面像是target_to_source流向型的, 感觉下面应该是source节点,上面是target节点,那么这个就更加清楚了。就比如0那一个来说, 相当于流向0节点的邻居节点embedding的聚合之后,得到最终0位置那个节点的embedding,其实整个过程就干了个这样的事情而已。聚合之后,得到的结果,就是GCN的输出了。

        • 执行update函数,得到的就是聚合函数之后的结果了,这个就表示了信息传递之后的,各个节点特征表示。

这样, 前向传播的整个过程就梳理明白了,这里也知道了,所谓的消息传递,是把与之相邻的邻居节点的信息传到该节点上去, 最简单的方式,就是融入邻居节点的embedding信息。这里非常感谢下面的第二篇链接。

这里明白了之后,下面的内容就非常简单了,就是我们自定义message, aggregate以及update函数了。

3.3 自定义核心函数

3.3.1 message方法的重写 - 加更多的节点特征进去参与运算

前面介绍了,传递给propagate()方法的参数,如果是节点的属性的话,可以被拆分成属于中心节点的部分和属于邻接节点的部分,只需在变量名后面加上_i_j

现在我们有一个额外的节点属性,节点的度deg,我们希望meassge方法还能接收中心节点的度,我们对前面GCNConvmessage方法进行改造得到新的GCNConv类:

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        # "Add" aggregation (Step 5).
        # flow='source_to_target' 表示消息从源节点传播到目标节点
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        return self.propagate(edge_index, x=x, norm=norm, deg=deg.view((-1, 1)))

    def message(self, x_j, norm, deg_i):
        # x_j has shape [E, out_channels]
        # deg_i has shape [E, 1]
        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j * deg_i

这里面的重点,是forward里面的最后一行,会发现需要传进节点的度参数,这个是[num_nodes, 1], 这里必须注意:

若一个数据可以被拆分成属于中心节点的部分和属于邻接节点的部分,其形状必须是[num_nodes, *],因此在上方代码的第29行,我们执行了deg.view((-1, 1))操作,使得数据形状为[num_nodes, 1],然后才将数据传给propagate()方法。

而具体用这个特征的时候,则是在message里面写怎么用。 这里的deg_i表示的就是拿到了边里面target侧所有节点的度大小。把这个也融入到了embedding里面去, 融入的方式是乘积。

这个方式可以让我们加更多的节点特征进去参与运算。当然,这里我想了下, 融入特征的时候,不仅可以乘进去,因为乘到embedding里面去,感觉没有道理呀,这个embedding和节点的度八竿子打不着,为啥要乘呢? 难度不能再搞出一个维度来,也就是拼接到embedding的那个维度上,使得返回结果是[E, out_channels+1]的? 所以接下来,我做了下面的尝试, 也就是修改message为这样:

 def message(self, x_j, norm, deg_j):
    """
    :param x_j: [E, out_channels]
    :param x_i: [E, 1]
    :param norm: 归一化的节点特征
    :return:
    """
    #print(deg_j.shape)
    #return norm.view(-1, 1) * x_j * deg_j   # [E, out_channels]
    x_j = norm.view(-1,1) * x_j
    return torch.cat((x_j, degmm_j), dim=1)  # [E, out_channels+1]

这样,就能够把节点的度数特征拼接到之前的embedding维度上去了,而不是无脑的直接乘到里面。 而这里接收参数的时候一定要注意, 名字对应起来才可以:
在这里插入图片描述
这样,就完成了加特征的操作,这个操作可是很重要的,一般节点往往会加入一些其他的特征,或者将一部分特征通过线性层,另一部分通过非线性,另一部分保持不变,最后拼接, 而这些只有真正懂了前向传播的原理之后,才能处理起来得心应手。 或者在propagate之前,先完成向量拼接,然后再给propagate传入x, 或者在message里面进行修改等。 不同的场景不一样, 但修改的地方是一样的。

再比如, 可能我相邻节点信息的重要性等对我当前节点可能不一样, 所以在后面汇合的时候,我不想直接所有相邻接点的embedding相加, 而是想先通过一个神经网络学习一个权重出来, 这个可以直接过神经网络了,或者是与当前节点embedding想的相似度求出一个权重来表示每个邻居节点embedding的重要性, 那么我该如何做呢?

由于我看了下下面的汇合方法,这个可能是提供了无脑相加,或者是求平均,最大等, 不太好改,所以我们如果想实现上面的这种需求,依然是需要在forward方法里面改, 也就是在propatate之前,实现这块的逻辑,修改x向量本身也就是x=self.linear(x)其实非常灵活,这里的linear可以换成任意的mlp,所以即使加attention也没有任何问题感觉

当然我目前是感觉哈,如果这块不行, 那么就从message里面实现这样的逻辑,因为message里面,最终得到的是每条边source端节点的embedding或者说最终表征, 而下面是相邻节点表征的聚合,那么这个embedding的得出逻辑,我们其实是可以从message里面修改的,上面的两个就是最好的例子。

但如果加attention的话,目前感觉不如从前向传播那里加容易,当然,如果这个权重,是依据当前边的embedding与相邻节点embedding相似度得到的,那么显然,这个需要基于边的索引情况,先找到当前点的所有邻边, 再拿到对应embedding,与当前边embedding求出相似度来,乘到对应边上去,这样就能得到加权的embedding了,所以这个,得从message里面去实现。这个得有边的信息,可以在初始化的时候定义一个self.edge_index接收下吧。这样在message里面就能算了。

3.3.2 aggregate方法的覆写

在前面的例子的基础上,我们增加如下的aggregate方法。通过观察运行结果我们可以看到,我们覆写的aggregate方法被调用,同时在super(GCNConv, self).__init__(aggr='add')中传递给aggr参数的值被存储到了self.aggr属性中。 这个感觉没有啥大作用, ptr这个

def aggregate(self, inputs, index, ptr, dim_size):
    print('self.aggr:', self.aggr)
    print("`aggregate` is called")
    return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)

这里的inputs就是message里面的输出, index边的信息,为了得到当前点邻居节点的embedding, 这里的ptr保证的是沿着正确的维度传递,这里是沿着x的-2维度,也就是样本的维度传递的。这个具体的没怎么看明白。

3.3.3 message_and_aggregate方法的覆写

在一些案例中,“消息传递”与“消息聚合”可以融合在一起。对于这种情况,我们可以覆写message_and_aggregate方法,在message_and_aggregate方法中一块实现“消息传递”与“消息聚合”,这样能使程序的运行更加高效。

def message_and_aggregate(self, adj_t, x, norm):
    print('`message_and_aggregate` is called')
    # 没有实现真实的消息传递与消息聚合的操作

这里啥也没实现,也就是说message和aggregate的操作可以合并到一块写,但目前不知道这块咋操作,因为发现aggregate那里还是挺复杂的,写到一块就不知道怎么操作了。

运行程序后我们可以看到,虽然我们同时覆写了message方法和aggregate方法,然而只有message_and_aggregate方法被执行。

3.3.4 覆写update方法
 def update(self, inputs, deg):
    print(deg)
    return inputs

update方法接收聚合的输出作为第一个参数,此外还可以接收传递给propagate方法的任何参数。在上方的代码中,我们覆写的update方法接收了聚合的输出作为第一个参数,此外接收了传递给propagatedeg参数。

感觉这里面也能做些后续的处理工作,也就是聚合完成之后,得到每个节点的最终表示的时候, 或许还可以做些修正操作等。也就是改inputs, 或者拿这个inputs再去接神经网络,做别的运算等吧,也可以在这里融入一些其他的信息吧。 这些地方感觉就比较灵活了。

4. 小总

哇,到这里终于这篇文章结束了,这个花了我大约1天的时间,由于实习,发现实在是难以在工作日的时候搞这个,所以只能在休息日的时间做下,但收获还是很多的,至少对于这种图上融合节点信息有了一个更加深刻的了解,也对GNN的前向传播有了更加细致的了解,原理GNN的节点embedding可以这样算呀,学到了。

下面是文档里面的总结:

消息传递范式是一种聚合邻接节点信息来更新中心节点信息的范式,它将卷积算子推广到了不规则数据领域,实现了图与神经网络的连接。该范式包含这样三个步骤:(1)邻接节点信息变换、(2)邻接节点信息聚合到中心节点、(3)聚合信息变换。因为简单且强大的特性,消息传递范式现被人们广泛地使用。基于此范式,我们可以定义聚合邻接节点信息来生成中心节点表征的图神经网络。在PyG中,MessagePassing基类是所有基于消息传递范式的图神经网络的基类,它大大地方便了我们对图神经网络的构建。

另外就是通过和群里伙伴的一点讨论,引起了一点反思, 就是究竟要不要看那么细的问题,因为第一遍看的时候x_j以及聚合那块是没弄清楚是咋回事的,所以就向群里伙伴求助, 有伙伴说这地方不用管底层怎么实现,不用重复造轮子,知道怎么执行的就可以了。源码里已经写好了,不用再操心了,一般可以发挥的地方就是覆写代码的地方,造轮子比较复杂,对于入门者难度大,又不实用,所以就没介绍

在这里插入图片描述
当然我不是说伙伴们说的有问题哈, 只是从我自己的角度出发,我觉得这样学习还是有些危险的,虽然说是入门,但是重要的细节我觉得还是有必要弄清楚的,并且是需要弄的尽量清楚的。 就比如上面GNN信息传递的具体过程,如果这个我们自己都觉得,底层已经实现好了,不用我们操作,只会覆盖方法即可,如果我们自己都觉得不用研究太深入,我们目前入个门即可,那么我们就可能真的连门都进不去了,因为我们在实现目标的时候往往会打折扣,如果只是把目标定为入门,那么我们估计很难入门,所以这样的想法本身比较危险,学习知识虽然不是说每个点都知其所以然,但基础知识,我们还是有必要知其所以然的,这算是一种求知的态度。 毕竟我们是花了时间和精力的去学习这些东西的,但学完了之后,掌握个马马虎虎,用的时候又不敢放心用,还得重新找资料学习的话,那我们这次学习的意义在哪? 就为了知道"哦,这个问题我之前学过,我去复习复习?" 。并且这样,也不利于学习的串联性, 更没法达到看到节点嵌入,立即想到词嵌入,又立即串联到NLP的各个知识w2v->elmo-transformer->bert等, 而知识的串联性,更有利于我们加深知识本身的理解。 但这些的前提,都是需要我们学习每个新知识的时候,首先端正态度才行, 这个态度,对于我自己来讲,首先是有耐心,其次就是重要知识一定要知其所以然。实习之后,我更加体会到了时间的宝贵性,在工作日中,每天节奏太快,我发现根本就没法抽时间学习自己喜欢的知识, 光紧迫性任务都学不过来,到处是盲区,也可能是自身太菜了,掌握的东西太少。 所以这种课余学习机会我现在是非常珍惜的,这也是我为啥有些东西想尽量想明白的原因。这些东西预感就是未来一定会用到,所以现在开始,就想先把基础打好。

对于GNN信息传递过程,我觉得对于学GNN来讲应该是最底层的基础了,这个重要度就类似于手写神经网络的前向传播那样。 如果我们连这个东西都搞不明白,也不知道聚合是咋聚合的,message里面到底干了些啥? 我真的有些担心,面对具体应用时,我们能写出正确的覆盖方法来吗? 即使幸运的写出来了,但效果好不好,为啥好与不好,再怎么改? 到时候我们又可能脑海中憋了好久,最后答出个"玄学", 这个反正,在工业上或者公司中,我可不敢这么做。

虽然说从表面是看,是根据实际应用,覆盖message,聚合和update, 但是不知道这三个内部怎么配合的,又怎么写相应的实现逻辑,就比如我考虑的那个问题,每个邻居节点可能对于当前节点的重要程度不同,我想先根据邻居节点与当前节点相关性加个权重,然后再sum聚合起来,这个怎么改? 再比如,我想加节点新特征信息,但是不要无脑乘到embedding里面去,而是扩展出维度来,应该怎么改?

今天是由和伙伴们的讨论中,引出了自己的一些感想, 再次声明,我并没有一点说别人说的有问题,也并没有一点表达我自己认为的多么正确的意思,只是说,每个人所处环境不同,所思所想不同,这个是非常正常的,每个人都有自己的看法和态度,这个只是我自己的态度而已哈,并没有谁对谁错,因为随着经历的不同,我相信我后面的想法也会变得不一样, 但不停的记录自己每个时刻的态度和想法,对我来说,也是一件非常有意思的事情。

扯得有点远了,还是看看两个作业吧:

  1. 请总结MessagePassing基类的运行流程。
  2. 请复现一个一层的图神经网络的构造,总结通过继承MessagePassing基类来构造自己的图神经网络类的规范。

首先,第一个问题,也就是我说的最基础的问题了,整体运行流程,就是把相邻节点的信息融入到节点本身。具体实现的话,最根本的就是落在了三个函数上:

  • message函数中,针对每一条边,我们需要获得各个source节点的特征信息,因为说了要把相邻节点的信息融入到target节点嘛, 那么得先拿到邻居节点的信息呀,这个函数非常重要, 因为邻居节点的信息究竟怎么得到,在这里面可以进行适当的修改,比如加特征等
  • propagate函数中, 是我们通过message函数,已经能够获取到对于某个中心节点其邻居节点的特征表示, 那么我们怎么合成一个向量呢? 这里可以求和,求平均和最大值等,这样就把邻居节点信息汇总成了一个embedding, 这个就是该中心节点最终的embedding
  • update函数,这里接收上面的embedding,可以直接返回,也可以再做一些修正工作。

第二个问题,文档中已经给出了一个比较标准的图神经网络了,并且还详细写了如何用, 后面用的时候,其实就可以基于这个例子, 文档中给出的算是一个比较规范的范式了:

class GCNConv(MessagePassing):

    def __init__(self, in_channels, out_channels):
        """
        初始化: 这里在基类的基础上定义了一个Linear层
        :param in_channels: 输入特征个数
        :param out_channels: 输出特征个数,Linear的神经单元数量
        """
        # 初始化这个一定要符合这个
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
		
		# 这里可以是线性的,感觉也可以加MLP,或者更复杂的网络等,即使加个transformer,感觉也能跑通,就看业务上是否过的去了哈哈
        self.lin = torch.nn.Linear(in_channels, out_channels)
        
	# 具体计算逻辑还是在forward里面, 但消息传递在下面的progagate
    def forward(self, x, edge_index):
        """
        前向传播逻辑
        :param x: 节点表示矩阵, [N, in_channels], N个节点, in_channles个特恒
        :param edge_index: 边的信息,[2, E], 2行, E列,每一列表示一条边,索引指向(起始点索引->终止点索引)
        :return:
        """
        # 1. 增加邻接矩阵的自循环, 这个函数做的事情,就是每个节点i到i再增加一条边,把索引填到最终的edge_index中
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))  # _那个是权重,这里无权,所以不要这个,看下源码就明白了
        #print(edge_index)

        # 2. 特征矩阵的线性变换
        x = self.lin(x)   # [N, out_channels]

        # 3. 计算归一化系数, 这里的deg不知道是个啥东西了
        start, end = edge_index
        deg = degree(start, x.size(0), dtype=x.dtype)  # 这个函数在计算每个节点的度
        #print(deg)         # 每个节点的度  [4., 4., 6.,  ..., 2., 5., 5.]
        #print(deg.shape)   # 2708

        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[start] * deg_inv_sqrt[end]

        # 4-5 开始消息传递  在这个之前,上面其实x可以任意计算修改,特征也可以灵活加入,只要能传对参数
        print(deg.view((-1, 1)).shape)
        return self.propagate(edge_index, x=x, norm=norm, deg=deg.view((-1, 1)))

    def message(self, x_j, norm, deg_j):
        """
		这个是消息生成的具体逻辑,这里面依然是可以非常灵活的修改每条边source端点的embedding的计算逻辑
        :param x_j: [E, out_channels]
        :param x_i: [E, 1]
        :param norm: 归一化的节点特征
        :return:
        """
        #print(deg_j.shape)
        #return norm.view(-1, 1) * x_j * deg_j   # [E, out_channels]
        x_j = norm.view(-1,1) * x_j
        return torch.cat((x_j, deg_j), dim=1)
	
	# 这个感觉不好变了
	def aggregate(self, inputs, index, ptr, dim_size):
        print('self.aggr:', self.aggr)
        print("`aggregate` is called")
        return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)
	
	# 这个目前没弄明白具体怎么配合,源码有点难懂,由于时间原因,这个先不看
    def message_and_aggregate(self, adj_t, x, norm):
        print('`message_and_aggregate` is called')
        # 没有实现真实的消息传递与消息聚合的操作
	
	# 这里面是最后结果产生的逻辑, 也可以在聚合的基础上进行修正
    def update(self, inputs, deg):
        print(deg)
        return inputs

if __name__ == '__main__':
    dataset = Planetoid(root='../dataset/Cora', name='Cora')
    data = dataset[0]     # 图结构拿到
	
	# 有了图结构, 图神经网络的定义,往往给定输入特征,输出特征
    net = GCNConv(data.num_features, 64)

	# 一层的正向传播, 这里把图给传进去,重要的是图节点数据和边的连接数据
    h_nodes = net(data.x, data.edge_index)   # x [2708, 1433], data.edge_index [2, 10556]

	# h_nodes就是一层GNN的输出, 基于这个,可以再进行别的运算等

之所以,我一再强调GCN前向传播逻辑的重要性, 是因为一旦梳理明白了,会很容易迁移到自己的任务上去, 上面这个代码,我加了自己的修改逻辑理解。

参考

### GNN 中的消息传递机制工作原理 图神经网络(Graph Neural Network, GNN)的核心在于其能够通过消息传递机制来捕获图结构中的复杂关系。这种机制允许节点其邻域之间的信息交换,从而使每个节点可以基于周围环境更新自身的特征表示。 #### 消息传递的基本流程 消息传递机制通常分为三个主要阶段:消息计算、聚合和状态更新。以下是这些阶段的具体描述: 1. **消息计算** 在这一阶段,每个节点会根据其当前的状态及其邻居节点的状态生成一条或多条消息。具体来说,对于某个节点 \(v\) 和它的邻居节点 \(u \in N(v)\),可以通过某种函数 \(M_t\) 计算从 \(u\) 到 \(v\) 的消息: \[ m_{t}^{(v)} = M_t(h_v^{t-1}, h_u^{t-1}) \] 这里 \(h_v^{t-1}\) 表示节点 \(v\) 在第 \(t-1\) 轮迭代时的隐藏状态,\(m_{t}^{(v)}\) 是本轮产生的消息[^1]。 2. **消息聚合** 接下来,所有来自邻居节点的消息会被汇总成单个向量。这一步骤通常由一个聚合函数 \(A_t\) 完成,常见的形式包括求和、平均值或最大池化等操作: \[ a_{t}^{(v)} = A_t(\{m_{t}^{(u)} : u \in N(v)\}) \] 此外,某些高级模型可能引入注意力机制,使得不同的邻居贡献不同程度的影响[^4]。 3. **状态更新** 最后,利用聚合后的消息对目标节点的状态进行更新。此过程一般借助于递归单元或其他类型的转换函数 \(U_t\) 实现: \[ h_v^t = U_t(h_v^{t-1}, a_t^{(v)}) \] 经过多轮这样的迭代,初始输入逐渐被丰富的上下文信息所增强,最终形成具有表达力的嵌入表示[^2]。 #### 技术挑战优化方向 尽管上述理论框架简洁明了,但在实际部署过程中仍需应对若干难题,比如内存消耗过大可能导致训练效率低下等问题[^3]。因此研究者们不断探索新的策略以提高性能并降低资源需求。 ```python import torch from torch_geometric.nn import MessagePassing class SimpleMessagePassingLayer(MessagePassing): def __init__(self): super().__init__(aggr='add') # "Add" aggregation. def forward(self, x, edge_index): return self.propagate(edge_index, x=x) def message(self, x_j): # x_j has shape [E, out_channels] return x_j def update(self, aggr_out): # aggr_out has shape [N, out_channels] return aggr_out ``` 以上代码片段展示了一个基本的消息传递层定义方式,其中 `message` 方法负责创建消息,而 `update` 方法则用于整合收到的信息完成最后的状态刷新动作。
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值