【图神经网络】 GraphSAGE 原文精讲(全网最细致篇)

请添加图片描述

论文精讲部分

很久没更新了,最近事情比较多,之前本专栏深入探讨图神经网络模型相关的学术论文(GCN,GAT),并通过实际代码来提高理解。但是还是少了一个经典模型才算是对这部分有了一个全面的理解,那就是GraphSAGE,为了填补这个遗憾。我在本章节详尽解读了“GraphSAGE模型”的原始论文,适合希望深入了解的读者。对于零基础或初学者,推荐阅读我编写的入门内容,它非常适合新手,点击这里查看易懂的GNN入门材料

在本章节中,大部分内容将直接引用原论文翻译,而对论文的理解和分析则以引用格式呈现,旨在清晰展示对当前研究主题的深入见解。考虑到相关资料相对稀缺,希望各位可以给个免费的点赞收藏。

在这里插入图片描述

😃如果您觉得内容有帮助,欢迎点赞、收藏,以及关注我们哦!👍

论文地址https://arxiv.org/pdf/1706.02216

0. 摘要

在大图中的 节点的低维嵌入在多种预测任务中证明了其极大的有用性, 从内容推荐到识别蛋白质功能。

这里就是再说GNN的聚合方式生成的节点嵌入表示,对预测任务比如分类还是比最初的特征生成的性能要好的。

然而,大多数现有方法要求在嵌入训练过程中图中的所有节点都必须存在

让我们用一个例子来阐明这一点:假设 ( AX ) 表示的操作,其中 ( A ) 是邻接矩阵,而 ( X ) 是节点的特征矩阵。像GCN(图卷积网络)或GAT(图注意力网络)这样的方法通常都会采用这种方式进行图卷积操作。具体来说,这种矩阵乘法操作在结构信息的指导下聚合了邻居的信息。在这一过程中,所有节点的信息都会被更新一次。上述描述的正是在加入新节点时,需要重新训练模型的原因。也就是说,在 ( AX ) 中加入一个新节点,模型的训练参数必须重新调整,进而重新开始训练。
在使用GCN时,所有节点(包括训练和测试节点)的特征都会更新,但我们仅使用训练节点的分类结果来调整参数。实际上,测试节点的表示已经在训练过程中固定下来。这种方式的一个显著问题在于,引入新的节点意味着模型需要重新训练。
此外,第二个问题是大规模图上的乘积操作实施起来非常困难,这对计算能力提出了极大的挑战。事实上,在Python环境下,几千维度的矩阵乘法已经非常耗费计算资源。对于数百万节点的图,常规的图卷积操作几乎无法实现。
因此,总结起来,我们面对的两个主要问题是:一是大规模图的训练实现困难;二是引入新节点需要重新训练整个模型。这些问题突显了以往方法在处理动态图或大规模图时的局限性。

这些先前的方法本质上是转导(transductive)的, 并不能自然地推广到未见过的节点。

这个就和标题相呼应了,之前的方法被称为transductive。就是引入新节点原始的模型不好用,而这个Inductive就是作者的标题的模型就是可以在引入新节点后一样能用。

我们在此提出 GraphSAGE,这是一种通用的归纳框架,利用节点特征信息(例如,文本属性)高效生成未见数据的节点嵌入。

为每个节点单独训练嵌入不同,

之前的模型可以看成对全部的节点进行训练得到每个节点的嵌入表示。

我们学习一个生成嵌入的函数,

之前的就是全部节点嵌入都更新,而这个则不是。他是通过一个方法,把需要的信息都放进去,就能生成某一个节点的嵌入表示。注意这里用词是某一个而不是每个,这什么意思呢就是你需要哪个节点的表示我就通过这个函数给你生成哪个节点的表示。而之前的 ( AX ) 是我不管你要哪个我都给你更新了。然后都准备好你要什么我给你什么。举个例子你想吃一道菜一个厨师是你点单后给你做。而另一个厨师就是先把菜单上的菜都做好,等你点单了从做好的菜中拿出来给你。
这样就能看到两种模型的风格不同,显然这种点单再做菜更加符合逻辑。

通过从节点的局部邻域中采样和聚合特征来实现。

为什么之前的人不这么做呢??
这里还是要展开说说,图卷积为什么要是 A X AX AX的形式,这主要是因为,图结构的不确定性,大家想象下CNN中的卷积对象图片如果看成图结构是不是很规整的每一个节点的邻居都是8个,这样的不断的迭代更新迭代更新,所以人家有共享卷积核,权重共享机制。但是GNN使用到的图某一个节点的有三个邻居,有的有20个邻居,这样的不确定性使其不能设计一个固定的卷积核去进行卷积提取。如果你强行固定了邻居个数,就要有丢失掉有用信息的决心,这也是为什么GNN要是这样的卷积行为。所以在一系列数学推导下我们得到了 A X AX AX的各种形式变体。从而最终完成了图卷积。
这同样也限制了GNN在见到新节点的局限能力,以及大图性能。为什么说这么多呢???大家可以看看作者在摘要中提到 节点的局部邻域中采样和聚合特征来实现 这样的能力,使用过局部邻居换来的,不是全部的邻居而是局部,其限定了邻居的个数,从而实现的大图能力和他所定义的Inductive能力。

我们的算法在三个归纳节点分类基准上超越了强大的基线:我们分类基于引用和 Reddit 帖子数据在不断变化的信息图中未见节点的类别,并展示了我们的算法在使用蛋白质 - 蛋白质互动多图数据集时推广到完全未见图的能力。

上面就是常规的总结性能很强。实际上在笔者的测试试验中,这个GraphSAGE确实性能是比GCN和GAT好的,大图上更是如此。局部采样一定程度上增强了模型的鲁棒性,从而不再十分依赖于邻居信息或者是能够让模型更好的理解邻居信息进行辨别吧。好久没写文章话太多了一个摘要墨迹了半天。

1. 引言

大图中节点的低维向量嵌入已经被证明在各种预测和图分析任务中作为特征输入极其有用。节点嵌入方法背后的基本思想是使用降维技术将节点图邻域的高维信息浓缩成一个紧凑的向量嵌入。这些节点嵌入随后可以输入到下游的机器学习系统中,帮助完成节点分类、聚类和链接预测等任务。

又说了一遍这个节点的低维向量表示,就是使用GNN进行节点嵌入表示编码呗。现阶段神经网络可以看成一种降维方式。很常见的行为。扩展一点现阶段的多模态都是一样的东西就是。

然而,先前的工作主要集中在嵌入来自单一固定图的节点,许多实际应用则要求能够快速生成未见节点或全新(子)图的嵌入。

还是上面的问题,GCN就是对全部节点进行更新生成,最终会得到一个固定的嵌入表示,而在新图上则重新训练。

这种归纳能力对于高吞吐量的生产机器学习系统至关重要,这些系统在不断变化的图上运行,并不断遇到未见节点(例如,Reddit上的帖子,YouTube上的用户和视频)。

一个GCN只能在一个 固定图上进行使用,如果图发生了变化则GNCN要重训,所以这对比如YouTube这种用户的邻居变化的情况,不是很实用。

生成节点嵌入的归纳方法也有助于在具有相同特征形式的图之间进行泛化:例如,可以在来自模式生物的蛋白质-蛋白质互动图上训练一个嵌入生成器,然后使用训练好的模型轻松生成新生物数据的节点嵌入。

这里就是一种模式上的迁移,在已知的蛋白质上训练,试图通过这样的模型去应用于未见过的生物数据,或者是微调。

归纳节点嵌入问题尤其困难,相比转导设置,因为要泛化到未见节点需要将新观察到的子图“对齐”到算法已经优化的节点嵌入上。

这是什么意思呢?就是说子图对齐。上文说了局部邻居聚合,实现的节点嵌入更新。也就是说仅仅依赖于局部信息,但是问题来了。这就让你丧失了理解全局的能力了。客观的说GCN只要卷积两次就能够得到全局的信息因为第一次聚合的时候信息就在网络上进行大范围的传递了。实际上GCN性能好有一部分因素其具备全局信息。作者考虑如果使用局部邻居聚合全局信息这么获取呢??
所以下面他说:

一个归纳框架必须学会识别节点邻域的结构特性,揭示节点在图中的 局部角色以及其全局位置。

还是比较简单的,你实际一个固定的邻居个数就实现局部聚合,然后再使用nodevec进行全部节点的全局编码呗。当然这里不介绍什么是node2vec算了感兴趣的自己搜索下,或者是催我更新下。 很奇怪后面没在讨论这个全局位置是如何实现的,我这理解存在误区删了,不过这个方式确实很常见。各位可以思考下。

大多数现有的生成节点嵌入的方法本质上是转导的。这些方法中的大多数直接通过基于矩阵分解的目标来优化每个节点的嵌入,并且不会自然地推广到未见数据, 因为它们在一个单一、固定的图上的节点上做出预测

又说了一遍这个,GCN使用固定图进行训练,图固定了就是节点个数不变。以及位置都不变,回到厨师的那个例子。

这些方法可以修改以在归纳设置中运行,但这些修改往往计算成本高,需要额外的梯度下降轮次才能做出新预测。

也有使用卷积算子进行图结构学习的新方法,作为嵌入方法有前景,目前,图卷积网络(GCNs)仅在具有固定图的跨导设置中应用。在这项工作中,我们将GCNs扩展到归纳无监督学习任务,并提出一个推广GCN方法以使用可训练聚合函数(超越简单卷积)的框架。

GCN还是有很多可取之处,所以GraphSAGE可以通过设定方式修改成GCN。

目前的工作。我们提出了一个通用框架,称为GraphSAGE(采样和聚合),用于归纳节点嵌入。与基于矩阵分解的嵌入方法不同,我们利用节点特征(例如,文本属性、节点资料信息、节点度数)来学习一个泛化到未见节点的嵌入函数。通过在学习算法中结合节点特征,我们同时学习每个节点邻域的拓扑结构以及邻域内节点特征的分布。虽然我们专注于特征丰富的图(例如,具有文本属性的引用数据、具有功能/分子标记的生物数据),我们的方法也可以利用存在于所有图中的结构特征(例如,节点度数)。因此,我们的算法也可以应用于没有节点特征的图。

说了半天问题,终于说到了自己模型GraphSAGE(采样和聚合),他说了设计了一个嵌入函数。图结构的不稳定性就要修改,所以使用采样技术。这样的目的就是固定邻居个数,以及让他的这个函数固定规模。同时还指出我们同时学习每个节点邻域的拓扑结构以及邻域内节点特征的分布。如果你考虑了局部那么就会损失全局信息,所以通过node2vec获取全局的这样的结构信息。

我们没有为每个节点训练一个不同的嵌入向量,而是训练一组聚合器函数,这些函数学习从节点的局部邻域中聚合特征信息。

换言之的模型只需要知道一个节点邻居的一个结构信息,就能完成这个节点的更新不依赖于这样的 A X AX AX

每个聚合器函数从一个给定节点的不同跳数或搜索深度中汇集信息。

这里说的是不是只考虑一阶邻居同样可以考虑高阶的。

在测试或推理时,我们使用训练好的系统通过应用学习到的聚合函数生成未见节点的嵌入。

新引入一个节点,可以直接看起在图中的结构信息,在这样的限定阶数和邻居个数的情况下实现这样的邻居聚合。

借鉴之前生成节点嵌入的工作,我们设计了一个无监督损失函数,使GraphSAGE可以在没有任务特定监督的情况下进行训练。我们还展示了GraphSAGE可以在完全监督的方式下进行训练。

这个在后面可以看到其使用了node2vec进行对比说明他还设计了一个无监督的模型后面出现了再讲

我们在三个节点分类基准上评估了我们的算法,测试GraphSAGE在生成未见数据上的有用嵌入的能力。我们使用了基于引用数据和Reddit帖子数据的两个演变文档图(分别预测论文和帖子类别),以及基于蛋白质-蛋白质互动数据集的多图泛化实验(预测蛋白质功能)。通过这些基准测试,我们展示了我们的方法能够有效生成未见节点的表示,并且在不同的领域中显著优于相关的基线:我们的监督方法在分类F1分数上平均提高了51%,而不是仅使用节点特征,并且GraphSAGE始终优于一个强大的跨导基线,尽管该基线在未见节点上运行时间大约长100倍。我们还展示了我们提出的新聚合器架构相比受图卷积网络启发的聚合器提供了显著提高(平均7.4%)。最后,我们通过理论分析探究了我们方法的表达能力,表明GraphSAGE能够学习节点在图中的角色结构信息,尽管它本质上是基于特征的。

稍微夸夸自己的模型吧。

2. 相关工作

我们的算法在概念上相联系的先前的节点嵌入方法、图上学习的一般监督方法以及将卷积神经网络应用于图结构数据的最新进展相关。

基于因子分解的嵌入方法。有一些最近的节点嵌入方法,使用随机游走统计和基于矩阵因子分解的学习目标来学习低维嵌入[5, 11, 28, 35, 36]。

这种因子分解的方式就是SVD对A进行分解,生成节点的嵌入表示。同样A变了其整体的节点嵌入都会发生变化

这些方法也与更经典的谱聚类[23]、多维缩放[19]以及PageRank算法[25]有密切关系。由于这些嵌入算法直接为单个节点训练嵌入,它们本质上是传递性的,并且在最小限度上,需要额外的昂贵训练(例如,通过随机梯度下降)来对新节点进行预测。此外,对于许多这些方法(例如,[11, 28, 35, 36]),目标函数对嵌入的正交变换是不变的,这意味着嵌入空间在图之间不能自然泛化,并且在重新训练时可能会漂移。杨等人提出的Planetoid-I算法[40]是对这种趋势的一个显著例外,这是一个基于嵌入的半监督学习方法。然而,Planetoid-I在推断过程中不使用任何图结构信息;相反,它在训练过程中使用图结构作为一种正则化形式。

Planetoid 现阶段适应的cora划分方式就是这个作者进行搞得,基于训练和测试同分布测试模型性能,这个哥们大家可以看看现在好像是在创业。他们上面说的这些技术,各位可以酌情看看,按照现在的时间确实很老的技术了,但是这同样考察一个人的全面性嘛。

与这些先前的方法不同,我们利用特征信息来训练模型,以生成未见节点的嵌入。

图上的监督学习。除了节点嵌入方法外,还有大量关于图结构数据的监督学习的文献。这包括各种基于核的方法,其中图的特征向量是由各种图核导出的(参见[32]及相关引用)。还有许多最近的神经网络方法用于图结构上的监督学习[7, 10, 21, 31]。我们的方法在概念上受到这些算法的启发。然而,尽管这些先前的方法试图分类整个图(或子图),本工作的重点是为单个节点生成有用的表示。

讨论到核就是数学层面的问题,当成一个最优化的问题看待,这些都是GCN前的大量工作者做的,我们现阶段看到的都是透过公式看行为,而之前的工作则是通过行为去写公式吧。这个不展开相关工作就这样。

图卷积网络。近年来,出现了几种用于图上学习的卷积神经网络架构(例如,[4, 9, 8, 17, 24])。这些方法中的大多数不能扩展到大型图或是为整图分类设计的(或两者兼有)[4, 9, 8, 24]。然而,我们的方法与Kipf等人提出的图卷积网络(GCN)[17, 18]密切相关。原始的GCN算法[17]设计用于传递性环境中的半监督学习,并且精确算法要求在训练期间已知完整的图拉普拉斯算子。我们算法的一个简单变体可以看作是GCN框架对归纳设置的扩展,这一点我们将在3.3节中重新讨论。

感兴趣的读者可以看我的GCN专栏对这些内容的讲解。

3 提出的方法:GraphSAGE

我们方法的关键思想是学习如何从 节点的局部邻居(例如,附近节点的度或文本属性)中聚合特征信息

回到了这个关键的问题,局部聚合是如何实现的。

我们首先描述GraphSAGE嵌入生成(即前向传播)算法,该算法生成节点的嵌入,假设GraphSAGE模型参数已被学习(第3.1节)。随后,我们描述如何使用标准的随机梯度下降和反向传播技术学习GraphSAGE模型参数(第3.2节)。

3.1 嵌入生成(即前向传播)算法

在本节中,我们介绍 GraphSAGE 算法用于生成节点嵌入的前向传播算法(算法 1),该算法假设模型已经过训练,参数已固定。

这里仅仅就是要展示下模型是如何聚合的

请添加图片描述

我们先看这个伪代码,我的解释顺序适合伪代码一致的

  1. G = ( V , E ) G = (V, E) G=(V,E) v ∈ V v \in V vV所以 v v v 是全部节点的中的一个,将全部节点的初始特征作为 h v 0 h_{v}^0 hv0
  2. 遍历 K K K 这个 K K K实际上就是你要模型捕获几阶的邻居信息,同样你也会有几个聚合器。换言之可以理解成神经网络的层。 K K K决定层数量。
  3. 遍历全部节点节点一个一个更新
  4. 找到这个节点的邻居进行聚合作为 h N ( v ) 1 h_{N(v)}^1 hN(v)1就是邻居的特征,实际上这个向量是多个邻居向量通过聚合器生成的。
  5. 将邻居信息和自身信息进行拼接线性映射作为节点最终的特征向量 h v 1 h_{v}^1 hv1
  6. 直到便利完全部节点结束。
  7. 对全部节点特征进行归一化
  8. 便利结束直到全部卷积活动完成
  9. 最终将 h v K h_{v}^K hvK作为模型的输出 Z K Z^K ZK

上述就是我对这个伪代码理解,如果不讨论这个聚合器的4-5实际上就是多次的 A X AX AX操作也能完成。
下面我们看看作者是怎么解释的

具体来说,我们假定已经学习了 K K K 个聚合函数的参数(记为 AGGREGATE k \text{AGGREGATE}_k AGGREGATEk,对于所有 k ∈ { 1 , … , K } k \in \{1, \ldots, K\} k{1,,K}),这些函数负责从节点邻居处聚合信息;以及一组权重矩阵 W k W_k Wk,对于所有 k ∈ { 1 , … , K } k \in \{1, \ldots, K\} k{1,,K},这些矩阵用于在模型的不同层或“搜索深度”之间传播信息。第3.2节将描述我们如何训练这些参数。

算法 1 的直观理解是,在每次迭代或搜索深度时,节点从其本地邻居聚合信息,随着这个过程的迭代演进,节点逐渐从图的更远区域获取越来越多的信息。

这里为什么要说迭代搜索深度呢,就是说我们看伪代码的第二步k是实在控制一个节点聚合了几次邻居信息,聚合第一次节点只能聚合到一跳邻居的信息,随着次数够多,其一跳邻居同样也聚合了自身的邻居信息,这样就实现了深度信息的获取,就是 A A A X AAAX AAAX 只不过他才用了循环的方式不再采用矩阵乘法,这么做最大的优势就是可以在大图中使用,为什么说 GraphSAGE 是空域图卷积的代表就是其摆脱了数学层面的推导,直接从形式上去理解这个卷积行为。简而言之 A X AX AX 数学推导一大堆不就是为了获取邻居更新特征更新自身特征嘛,我直接聚合不久好了吗,循环也能做到为什么还要用乘法呢。所以这就是其在大图能力上的来源。不用矩阵乘法而是使用循环。换言之这是对GCN的高度理解,使用便利循环避开数学乘法。

在整个图 G = ( V , E ) G = (V, E) G=(V,E) 以及所有节点的特征 x v x_v xv,对于所有 v ∈ V v \in V vV 提供作为输入的情况下,算法 1 描述了嵌入生成过程。

我们在下文中描述如何将此过程泛化到小批量设置中。

算法 1 的外部循环中的每个步骤如下进行,其中 k k k 表示外部循环中的当前步骤(或搜索的深度),

通过便利的方式决定网络的深度,同样的逻辑重复多次就是几次图卷积行为。

并且 h k h_k hk 表示该步骤中节点的表示:首先,每个节点 v ∈ V v \in V vV 聚合其直接邻域 { h k − 1 u , ∀ u ∈ N ( v ) } \{h_{k-1}^u, \forall u \in N(v)\} {hk1u,uN(v)} 中的节点表示为单个向量 h k − 1 N ( v ) h_{k-1}^{N(v)} hk1N(v)。注意,这一聚合步骤依赖于外部循环前一次迭代生成的表示(即 k − 1 k - 1 k1),并且 k = 0 k = 0 k=0(“基础情况”)的表示定义为输入节点特征。

先有 h v 0 h_{v}^0 hv0 然后通过 h v 0 h_{v}^0 hv0计算邻居和然后更新得到 h v 1 h_{v}^1 hv1不断的更新。 h v 0 h_{v}^0 hv0最初就是节点的特征信息初始特征啥的。

在聚合邻居特征向量之后,GraphSAGE 将节点当前的表示 h k − 1 v h_{k-1}^v hk1v 与聚合的邻域向量 h k − 1 N ( v ) h_{k-1}^{N(v)} hk1N(v) 连接起来,这个连接向量通过一个具有非线性激活函数 σ \sigma σ 的全连接层,将表示转换为用于算法下一步的 h k v , ∀ v ∈ V h_k^v, \forall v \in V hkv,vV

就是将两个向量拼接用一个全连接加一个激活做向量压缩,换言之就是用神经网络降低维度。

为了方便表示,我们将在深度 K K K 输出的最终表示记为 z v ≡ h K v , ∀ v ∈ V z_v \equiv h_K^v, \forall v \in V zvhKv,vV。邻域的聚合可以通过多种聚合器架构来完成(由算法 1 中的 AGGREGATE \text{AGGREGATE} AGGREGATE 占位符所表示),我们将在下文的第 3.3 节讨论不同的架构选择。

聚合器没什么神秘的就是拿到了邻居向量你怎么处理呗,加权求和?就是GAT如果仅仅使用均值GCN如果你喜欢高级点的那就LSTM。

为了将算法1扩展到小批量(minibatch)设置。

小批量的这个伪代码我再讲讲,确实这就是这个代码的核心。

请添加图片描述

为什么要分batch一方面是算力的问题。另一方面你GCN想分不是也分不了吗。需要这样的便利方式去理解GCN的思想才能设计出这样的便利行为进行分批,各位好好理解理解。我想介绍伪代码,然后在讲透彻 这个东西,不难就是需要点具像化的思考

  1. B B B 就是batch的缩写,所以这里采用 B B B B B B就是我们第一批要更新的节点,里面存着索引信息我们将 B B B 存进 B K B^K BK.假设我们这个第一次要更新四个节点
  2. 第二行变化来了,这是一个逆序。K=4的话。就是从 B 4 B^4 B4 然后创建 B 3 B^3 B3一步一步来各位理解下他的魅力。我们这个循环执行4次一次会得到 B 4 B^4 B4 B 3 B^3 B3 B 2 B^2 B2 B 1 B^1 B1 .还有 B 0 B^0 B0 .
  3. 将咱们的 B k B^k Bk.中的节点都给 B k − 1 B^k-1 Bk1就是复制下作为 B k − 1 B^k-1 Bk1
  4. 便利 B k B^k Bk中的全部节点,
  5. 找到他们的邻居,存到 B k − 1 B^k-1 Bk1
  6. 便利结束 B k B^k Bk中的节点就表明 B k − 1 B^k-1 Bk1 创建完了。
  7. 便利生成我们后续的 B 1 B^1 B1 B 0 B^0 B0 .循环结束。这里你可能发现 B 1 B^1 B1 .的节点是最多的,涵盖了之前生成全部节点,并且还有些自己独有的。为什么要这样呢?这只是采样的部分,我们接着看后续。
  8. 这部分看起来和咱们之前的全部节点更新没啥差别。但是大家注意下之前的 v ∈ V v \in V vV现在是 v ∈ V v \in V vV, 现在是 v ∈ B 0 v \in B^0 vB0 就是我们从最大的节点集合开始,开始聚合
  9. 遍历 K K K 这个 K K K实际上就是你要模型捕获几阶的邻居信息,同样你也会有几个聚合器。换言之可以理解成神经网络的层。 K K K决定层数量。现在这里就是正序了。
  10. 遍历 B B B中的节点一个一个的更新更新。大家发现美这个 B B B会愈来越小,而且 B K B^K BK.- B 0 B^0 B0一直存在着我们.batch的节点K为几他们就更新了几次。剩下的就都一致了才完成了这样的batch更新。
  11. 找到这个节点的邻居进行聚合作为 h N ( v ) 1 h_{N(v)}^1 hN(v)1就是邻居的特征,实际上这个向量是多个邻居向量通过聚合器生成的。
  12. 将邻居信息和自身信息进行拼接线性映射作为节点最终的特征向量 h v 1 h_{v}^1 hv1
  13. 直到便利完全部节点结束。
  14. 对全部节点特征进行归一化
  15. 便利结束直到全部卷积活动完成
  16. 最终将 h v K h_{v}^K hvK作为模型的输出 Z K Z^K ZK

这个行为我之前看的时候就有些绕,如果有的通许没看懂我们再结合一个图进行展示下:

以下图为例:

请添加图片描述
假设网络层数 K = 3 K=3 K=3 ,当前batch中只有红色节点。就是只需要更新这一个点。

初始的时候令 B 3 B^3 B3 只包含红色节点;伪代码的上半部分,

采样 的时候是从 B 3 B^3 B3 B 0 B^0 B0 进行的。 B 2 B^2 B2 采样的是 B 3 B^3 B3 的1-hop邻居,同时加上 B 3 B^3 B3 本身,所以 B 2 B^2 B2 包括图中的绿色+红色节点。类似的, B 1 B^1 B1 采样的是 B 2 B^2 B2 的1-hop邻居,同时加上 B 2 B^2 B2 本身,所以 B 1 B^1 B1 包括图中的蓝色+绿色+红色节点。类似的, B 0 B^0 B0 包括图中的黄色+蓝色+绿色+红色节点。

聚合 操作就是聚合邻居的embedding,来更新自身的embedding。聚合与采样类似,也是分层进行的,只不过方向和采样相反。比如 K = 3 K=3 K=3 时,需要聚合3层,每层又需要聚合多次。下图展示了 k = 1 , 2 , 3 k=1,2,3 k=1,2,3 时的聚合情况。

k = 1 k=1 k=1 为例,此时,所有在 B 1 B^1 B1 里的节点都是目标节点,都需要聚合邻居的信息,包括如下聚合过程:
请添加图片描述

  1. 黄色节点→蓝色节点
  2. 蓝色节点→绿色节点
  3. 绿色节点→红色节点

上面→表示聚合方向。为什么是这样你思考下,外层节点是内层节点邻居。

注意所有→左边的embedding都是 h k − 1 = 0 h^{k-1=0} hk1=0 的embedding,即上一个循环时的embedding。比如第2步用的蓝色节点并不是第1步聚合得到的蓝色节点,而是上一个循环得到的蓝色节点(上一个循环为初始 h 0 h^0 h0 )。所以,上述三次聚合互不影响,可以并行进行。

当所有节点聚合完成之后,→右边的embedding变成了 h k = 1 h^{k=1} hk=1 的embedding,作为下一层 k = 2 k=2 k=2 时的左边embedding。

k = 2 k=2 k=2 时,最外层的黄色节点已经不参与计算了,此时包括如下聚合过程:

  1. 蓝色节点→绿色节点
  2. 绿色节点→红色节点

虽然绿色节点还是只聚合其直接邻居蓝色节点,但是由于蓝色节点在上一轮中聚合了黄色节点,所以绿色节点在这一轮中能够通过蓝色节点间接聚合到黄色节点,即绿色节点聚合到了其2-hop邻居。类似的,红色节点也聚合到了其2-hop邻居即蓝色节点。

k = 3 k=3 k=3 时,蓝色节点也已经不参与计算了,此时包括如下聚合过程:

  1. 绿色节点→红色节点

根据上面的分析,红色节点能间接聚合到其3-hop邻居,即最远聚合到黄色节点的信息。

三层聚合结束之后,最终我们得到了红色节点的embedding。可以看到,为了得到红色这一个节点的embedding,如果网络层数为3的话,其最终聚合了三层节点的信息。

说了半天大家再理解理解为什么说batch就是 B 0 B^0 B0 里面的节点更新了三次。

另外一个值得提醒的是,采样的过程是从 B 3 B^3 B3 B 0 B^0 B0 降序进行的,主要是为方便后续聚合的时候从从 B 0 B^0 B0 B 3 B^3 B3 进行。

我们上面基本介绍完了其算法的聚合方式。其中的聚合函数总结下就这四个:
其实还是很好理解的,就是对邻居的不同操作方式。

请添加图片描述
聚合器概述与比较:

Mean Aggregator
Mean aggregator 是最简单的聚合操作,它通过计算所有邻居的均值(不包括节点 v 本身)进行聚合,将得到的均值向量与节点自身的特征向量拼接(concat)后,再应用非线性激活函数。这种方法简单直接,易于实现。

GCN Aggregator
Mean aggregator 非常类似,GCN aggregator 的主要区别在于它在聚合邻居的信息时也包括了节点 v 自身(相当于自回路的存在)。这种设计是基于节点自身也是其邻域信息的一部分的考虑。然而,与 Mean aggregator 不同的是,在非线性激活之前,GCN 不进行自身状态的拼接,这类似于在 ResNet 架构中避免信息丢失的短路机制。因此,GCN 的网络层数通常不宜过深,且在某些情况下效果可能不如 Mean aggregator

Pooling Aggregator
Pooling aggregator 先对所有邻居的特征通过一个多层感知机(MLP),参数为 (W_{\text{pool}}) 和 (b_{\text{pool}}),进行处理,然后执行逐元素的最大池化(max pooling)操作。随后将池化结果与节点自身的特征拼接,最后应用非线性激活函数。在实验中发现,使用最大池化和平均池化在此聚合器中表现相似。

LSTM Aggregator
LSTM aggregator 是一种较为独特的聚合方式。与前述聚合器不同的是,LSTM 作为一种序列模型,对输入的顺序敏感。尽管如此,研究发现,即使对邻居进行随机排序后输入 LSTM,依然能够获得良好的聚合效果。

Graph Attention Network (GAT)
一个自然的改进思路是引入注意力机制来聚合邻居,即 Graph Attention Network (GAT)。在 GAT 中,不同邻居的特征可以根据他们与中心节点的相对重要性被不同地加权,实现更为灵活和有效的特征聚合。

啰嗦了半天我们互道论文正文。

与 Weisfeiler-Lehman 同构测试的关系:GraphSAGE 算法在概念上受到了用于测试图同构的经典算法的启发。如果在算法1中,我们(i)将 K K K设为 ∣ V ∣ |V| V,(ii)将权重矩阵设置为单位矩阵,并(iii)使用适当的哈希函数作为聚合器(无非线性操作),那么算法1就是 Weisfeiler-Lehman(WL)同构测试的一个实例,也称为"naive vertex refinement"。如果算法1为两个子图输出的表示集合 { z v , ∀ v ∈ V } \{z_v, \forall v \in V\} {zv,vV}是相同的,则WL测试宣布这两个子图是同构的。这个测试在某些情况下可能会失败,但在广泛的图类中是有效的。

GraphSAGE 是 WL 测试的一个连续近似,其中我们用可训练的神经网络聚合器替换了哈希函数。当然,我们使用 GraphSAGE 来生成有用的节点表示——而不是用来测试图同构。尽管如此,GraphSAGE 与经典的 WL 测试之间的联系为我们的算法设计提供了理论背景,以学习节点邻域的拓扑结构。

就是相同思想,在不同的问题下的应用,分组卷积的应用是为了减少计算量,但是在CNN中最初是为了增大算力。技术相似但是背景不同。

邻域定义:在这项工作中,我们统一采样一组固定大小的邻居,而不是在算法1中使用完整的邻域集,以保持每个批次的计算开销固定。

上面的两个算法都是默认邻居个数是多少个,实际上真实的GraphSAGE还是进行了一定的限制。

具体来说,我们定义 N ( v ) N(v) N(v)为从集合 { u ∈ V : ( u , v ) ∈ E } \{u \in V : (u, v) \in E\} {uV:(u,v)E}中均匀抽取的固定大小样本,并且在算法1的每次迭代 k k k中抽取不同的均匀样本。如果不进行这种采样,单个批次的内存和预期运行时间是不可预测的,在最坏的情况下为 O ( ∣ V ∣ ) O(|V|) O(V)。相比之下,对于GraphSAGE,每个批次的空间和时间复杂度是固定的,为 O ( ∑ i = 1 K S i ) O(\sum_{i=1}^K S_i) O(i=1KSi),其中 S i S_i Si i ∈ { 1 , … , K } i \in \{1, \ldots, K\} i{1,,K} K K K 是用户指定的常数。实际上,我们发现我们的方法在 K = 2 K = 2 K=2 S 1 ⋅ S 2 ≤ 500 S1 \cdot S2 \leq 500 S1S2500 的情况下可以实现高性能(详情见第4.4节)。

还是和GAT一样仅仅考虑两次的卷积行为。怕出现过平滑现象

3.2 学习 GraphSAGE 的参数

如果是半监督任务损失函数还是很好设计的,但是作者还是设计了一个无监督的任务。下面的损失函数是针对无监督算法的。什么意思呢就是做了一个负采样。让真实的邻居节点节点点积结果变大,
反之变小。这就是一个强假设,相邻阶段特征一致。

为了在完全无监督的设置中学习有用且可预测的表示,我们应用基于图的损失函数于输出表示 z u , ∀ u ∈ V z_u, \forall u \in V zu,uV,并通过随机梯度下降调整权重矩阵 W k , ∀ k ∈ { 1 , … , K } W_k, \forall k \in \{1, \ldots, K\} Wk,k{1,,K} 和聚合函数的参数。基于图的损失函数鼓励邻近节点拥有相似的表示,同时确保不同节点的表示具有高度区分性:
J G ( z u ) = − log ⁡ σ ( z u T z v ) − Q ⋅ E v n ∼ P n ( v ) [ log ⁡ σ ( − z u T z v n ) ] , ( 1 ) J_G(z_u) = -\log \sigma(z_u^T z_v) - Q \cdot \mathbb{E}_{v_n \sim P_n(v)} [\log \sigma(-z_u^T z_{v_n})], \quad (1) JG(zu)=logσ(zuTzv)QEvnPn(v)[logσ(zuTzvn)],(1)

v v v 是真的邻居,而这个 z u z_u zu 是假的邻居,让和真的邻居点积分很大假的很小,最小化这个函数。更新参数。随机游走就是从u出发长度固定的总是不会太远,所以都可以默认是u的几阶邻居。

其中 v v v 是在固定长度随机游走中与 u u u 邻近出现的节点, σ \sigma σ 是 sigmoid 函数, P n P_n Pn 是负采样分布, Q Q Q 定义了负样本的数量。重要的是,与以往的嵌入方法不同,我们输入到这个损失函数中的表示 z u z_u zu 是从节点的本地邻域内的特征生成的,而不是通过嵌入查找为每个节点训练一个独特的嵌入。

这种无监督的设置模仿了将节点特征提供给下游机器学习应用程序的情况,无论是作为一项服务还是存储在静态仓库中。在特定下游任务仅使用表示的情况下,无监督损失(方程 1)可以简单地被替换或通过特定于任务的目标(如交叉熵损失)进行增强。

就是任务不同你换个损失函数即可。比如分类就换成交叉熵。

3.3 聚合器

>实际上和我在上面讲解了,但是比较难懂所以各位想看也可以看看。我还是保留了。

不同于在 N 维网格上的机器学习(例如,句子、图像或3D体积),一个节点的邻居没有自然顺序;因此,算法1中的聚合函数必须能够操作一组无序的向量集。理想情况下,一个聚合函数应是对称的(即,对其输入的排列不变),同时仍需可训练并保持高表征能力。聚合函数的对称性质确保我们的神经网络模型可以训练并应用于任意排序的节点邻域特征集。我们考察了三种候选聚合函数:

  • 均值聚合器(Mean aggregator):
    我们的第一个候选聚合函数是均值操作符,我们简单地取 h k − 1 u , ∀ u ∈ N ( v ) {h_{k-1}^u, \forall u \in N(v)} hk1u,uN(v) 中向量的逐元素均值。均值聚合器几乎等同于传统 GCN 框架中使用的卷积传播规则。特别是,我们可以通过替换算法1中的第4行和第5行来衍生 GCN 方法的归纳变体:
    h k v ← σ ( W ⋅ MEAN ( { h k − 1 v } ∪ { h k − 1 u , ∀ u ∈ N ( v ) } ) ) . h_k^v \leftarrow \sigma(W \cdot \text{MEAN}(\{h_{k-1}^v \} \cup \{h_{k-1}^u, \forall u \in N(v)\})). hkvσ(WMEAN({hk1v}{hk1u,uN(v)})).
    我们称这种修改后基于均值的聚合器为卷积聚合器,因为它是局部谱卷积的粗略线性近似。此卷积聚合器与我们其他提议的聚合器的重要区别在于,它不执行算法1第5行的拼接操作——即,卷积聚合器不将节点先前层的表征 h k − 1 v h_{k-1}^v hk1v 与聚合的邻域向量 h k N ( v ) h_k^{N(v)} hkN(v) 进行拼接。这种拼接可以看作是 GraphSAGE 算法不同“搜索深度”或“层”之间的一种简单形式的“跳跃连接”,并且它带来了显著的性能提升。

  • LSTM 聚合器(LSTM aggregator):
    我们还考察了基于 LSTM 架构的更复杂的聚合器。与均值聚合器相比,LSTM 由于其较大的表达能力而具有优势。然而,需要注意的是,LSTM 本质上不是对称的(即,它们对排列不是不变的),因为它们以序列方式处理输入。我们通过简单地将 LSTM 应用于节点邻居的随机排列来适应无序集。

  • 池化聚合器(Pooling aggregator):
    我们考察的最后一个聚合器既对称又可训练。在这种池化方法中,每个邻居的向量独立通过一个全连接神经网络进行处理;经过此变换后,对邻居集执行逐元素最大池化操作来聚合信息:
    AGGREGATE pool k = max ⁡ ( { σ ( W pool h k u i + b ) , ∀ u i ∈ N ( v ) } ) , \text{AGGREGATE}_{\text{pool}}^k = \max(\{\sigma(W_{\text{pool}} h_{k}^{u_i} + b), \forall u_i \in N(v)\}), AGGREGATEpoolk=max({σ(Wpoolhkui+b),uiN(v)}),
    其中 max ⁡ \max max 表示逐元素最大操作符, σ \sigma σ 是非线性激活函数。基本上,最大池化之前的函数可以是任意深的多层感知机,但我们在本工作中关注简单的单层架构。这种方法受到最近在通用点集上应用神经网络架构的先进发展的启发。直观上,多层感知机可以被视为一组为邻居集中的每个节点表征计算特征的函数。通过对计算出的特征应用最大池化操作,模型有效地捕获了邻域集的不同方面。同时,请注意,原则上可以使用任何对称向量函数来代替最大操作符(例如,逐元素均值)。我们在开发测试中发现最大池化和均值池化没有显著差异,因此在我们的实验中专注于最大池化。

4 实验

实验细节其实就是一个两层的网络然后采用了采样的方式,我还是都翻译了便于各位阅读。咱们总结见,我就不多说了。

请添加图片描述

图2:A:在Reddit数据上的时间实验,训练批次大小为512,并在完整测试集上进行推理(79,534个节点)。B:模型性能与采样邻域的大小的关系,其中“邻域样本大小”指的是在K=2且S1=S2的情况下,每个深度采样的邻居数(使用GraphSAGE-mean在引用数据上进行)。

我们在三个基准任务上测试了 GraphSAGE 的性能:(i) 使用科学引文数据集(Web of Science)对学术论文进行不同学科的分类;(ii) 将 Reddit 帖子分类为不同的社区;以及 (iii) 在各种生物蛋白质-蛋白质相互作用(PPI)图上分类蛋白质的功能。第4.1节和第4.2节总结了数据集,补充材料中包含了额外的信息。在所有这些实验中,我们对在训练期间未见过的节点进行预测,并且在PPI数据集的情况下,我们在完全未见过的图上进行测试。

实验设置
为了使我们的归纳基准的实证结果具有参考性,我们与四个基线进行比较:随机分类器、忽略图结构的基于特征的逻辑回归分类器、代表因子分解方法的 DeepWalk 算法,以及原始特征与DeepWalk嵌入的串联。我们还比较了使用不同聚合函数(第3.3节)的四种 GraphSAGE 变体。由于 GraphSAGE 的“卷积”变体是 Kipf 等人的半监督 GCN 的扩展归纳版本,我们将此变体称为 GraphSAGE-GCN。我们测试了根据方程(1)中的损失训练的 GraphSAGE 的无监督变体,以及直接在分类交叉熵损失上训练的监督变体。对于所有 GraphSAGE 变体,我们使用了修正线性单元作为非线性激活函数,并设置 K=2,邻域样本大小 S1=25 和 S2=10(见第4.4节的灵敏度分析)。

对于 Reddit 和引文数据集,我们像 Perozzi 等人描述的那样使用 DeepWalk 的“在线”训练,在做出预测之前运行一轮 SGD 优化以嵌入新的测试节点(详见附录)。在多图设置中,我们无法应用 DeepWalk,因为在不同的独立图上运行 DeepWalk 算法生成的嵌入空间可能会彼此任意旋转(见附录D)。

所有模型都在 TensorFlow 中实现,使用 Adam 优化器(除了 DeepWalk,使用传统的梯度下降优化器效果更好)。我们设计实验的目标是 (i) 验证 GraphSAGE 相比基线方法(即原始特征和 DeepWalk)的改进,以及 (ii) 对不同的 GraphSAGE 聚合器架构进行严格的比较。为了提供公平的比较,所有模型在其小批量迭代器、损失函数和邻域采样器(如适用)上共享相同的实现。此外,为了防止在比较 GraphSAGE 聚合器时不经意的“超参数黑客行为”,我们为所有 GraphSAGE 变体扫描了相同的一组超参数(根据在验证集上的性能选择每个变体的最佳设置)。可能的超参数值是基于使用引文和 Reddit 数据子集的早期验证测试确定的,然后我们将这些数据从我们的分析中丢弃。附录中包含进一步的实施细节。

4.1 演变图上的归纳学习:引用和Reddit数据

我们的前两个实验是在演变信息图中对节点进行分类,这项任务在高吞吐量生产系统中尤为重要,因为这些系统不断遇到新的数据。

引用数据
我们的第一个任务是在一个大型引用数据集上预测论文的主题类别。我们使用了一个来自汤森路透Web of Science核心合集的无向引用图数据集,对应于2000-2005年间六个生物学相关领域的所有论文。该数据集的节点标签对应六个不同的领域标签。总的来说,这个数据集包含302,424个节点,平均度数为9.15。我们在2000-2004年数据上训练所有算法,并使用2005年的数据进行测试(30%用于验证)。对于特征,我们使用了节点度数,并根据Arora等人的句子嵌入方法处理了论文摘要,使用了GenSim word2vec实现训练的300维词向量。

Reddit数据
在我们的第二个任务中,我们预测不同Reddit帖子属于哪个社区。Reddit是一个大型在线讨论论坛,用户在不同主题的社区中发布和评论内容。我们从2014年9月发布的Reddit帖子中构建了一个图数据集。在这种情况下,节点标签是帖子的社区或“子版块”。我们抽取了50个大社区并构建了一个帖子到帖子的图,如果同一个用户在两个帖子上评论,就连接这两个帖子。总的来说,这个数据集包含232,965个帖子,平均度数为492。我们使用前20天的数据进行训练,剩余的天数用于测试(30%用于验证)。对于特征,我们使用了现成的300维GloVe CommonCrawl词向量;对于每个帖子,我们分别连接了(i)帖子的标题的平均嵌入,(ii)所有评论的平均嵌入,(iii)帖子的评分,以及(iv)帖子的评论数量。

表1的前四列总结了GraphSAGE以及基线方法在这两个数据集上的表现。我们发现GraphSAGE在很大程度上优于所有基线方法,并且可训练的神经网络聚合器相比GCN方法提供了显著的增益。例如,GraphSAGE-pool的无监督变体在引用数据上比DeepWalk嵌入和原始特征的拼接高出13.8%,在Reddit数据上高出29.1%,而监督版本分别提高了19.7%和37.2%。有趣的是,尽管LSTM聚合器是为顺序数据设计的,而不是无序集合,但它表现出强劲的性能。最后,我们看到无监督GraphSAGE的表现与完全监督版本相当,表明我们的框架可以在没有任务特定微调的情况下实现强大的性能。

4.2 跨图的泛化:蛋白质-蛋白质相互作用

我们现在考虑跨图泛化的任务,这需要学习节点角色而不是社区结构。我们将蛋白质的角色(基于其在基因本体论中的细胞功能)分类到不同的蛋白质-蛋白质相互作用图中,每个图对应于不同的人体组织。我们使用位置基因集、基序基因集和免疫签名作为特征,并使用基因本体集作为标签(总计121个),这些信息来自分子签名数据库。平均每个图包含2373个节点,平均度数为28.8。我们在20个图上训练所有算法,然后在两个测试图上平均预测F1分数(使用另外两个图用于验证)。

表1的最后两列总结了各种方法在这个数据集上的准确性。我们再次看到,GraphSAGE显著优于基线方法,LSTM和pooling聚合器相比于mean和GCN聚合器提供了显著的增益。

4.3 运行时间和参数敏感性

图2.A总结了不同方法的训练和测试运行时间。各种方法在训练时间上是可比的(GraphSAGE-LSTM最慢)。然而,需要为未见节点采样新的随机游走并运行新的SGD轮次,这使得DeepWalk在测试时比其他方法慢100-500倍。

对于GraphSAGE变体,我们发现设置K=2在准确性上提供了大约10-15%的持续提升;然而,将K增加到2以上在性能提升上收益递减(0-5%),且运行时间大幅增加10-100倍,具体取决于邻居采样大小。我们还发现,采样大规模邻居同样收益递减(图2.B)。因此,尽管子采样邻居带来了更高的方差,GraphSAGE仍能保持强预测准确性,同时显著改善运行时间。

4.4 聚合器架构间的比较总结

总体而言,我们发现LSTM和基于pool的聚合器在平均性能和实验设置中的表现都最好(表1)。为了更定量地洞察这些趋势,我们将每个不同的实验设置(即,3个数据集×(无监督 vs 监督))视为试验,并考虑哪些性能趋势有可能推广。特别是,我们使用无参数Wilcoxon有符号秩检验来量化不同聚合器之间的性能差异,并在适用时报告T统计量和p值。需要注意的是,该方法基于秩,基本上测试在新的实验设置中我们是否期望一种特定方法优于另一种。考虑到我们的样本量只有6种不同设置,这个显著性检验的能力有限;尽管如此,T统计量和相关p值是评估聚合器相对表现的有用定量措施。

我们看到LSTM、pool和mean聚合器相比GCN方法都提供了统计显著的增益(T = 1.0,p = 0.02)。然而,LSTM和pool方法相比mean聚合器的增益更为边际(T = 1.5,p = 0.03,比较LSTM与mean;T = 4.5,p = 0.10,比较pool与mean)。LSTM和pool方法之间没有显著差异(T = 10.0,p = 0.46)。然而,GraphSAGE-LSTM比GraphSAGE-pool显著更慢(大约慢2倍),这可能在整体上略微有利于基于池化的聚合器。

5 总结

终于是完事了,把GNN基础模型进行了梳理,但是下一章节咱们还要继续的来代码实战部分,希望感兴趣的各位能够加油继续学习起来。

GraphSAGE是GNN领域非常经典的文章,现在看来感觉所有想法都很自然啊,图不就应该聚合邻居来更新自身吗,但当时相出这个方法可能还是有其历史难度。
https://bitjoy.net/2022/05/31/%E3%80%8Ainductive-representation-learning-on-large-graphs%E3%80%8B%E9%98%85%E8%AF%BB%E7%AC%94%E8%AE%B0/
引用一下,这个博主的话我上图的例子就是参考他的理解引用了一部分进行讲解的。因为我觉得语言在面对这个例子的时候还是很匮乏的。大家感兴趣可以直接阅读他的原文。

如果您觉得还不错的话,可以奖励打赏小弟一杯咖啡钱,创作不易。如果你对GNN感兴趣,不妨点赞、收藏并关注,这是对我工作的最大支持和鼓励。非常感谢!如果有任何问题,欢迎随时私信我。期待与你的互动!
在这里插入图片描述

### GraphSAGE 代码实现 GraphSAGE(Graph Sample and Aggregate)是一种用于图神经网络中的节点表示学习方法,其核心在于通过采样和聚合邻居信息来构建节点的特征表示。下面展示的是基于 PyTorch 实现的一个简单版本的 GraphSAGE 模型[^2]。 #### 定义 SAGEConv 层 ```python import torch from torch.nn import Parameter, init from torch_geometric.nn.conv import MessagePassing import torch.nn.functional as F class SAGEConv(MessagePassing): def __init__(self, in_channels, out_channels): super(SAGEConv, self).__init__(aggr='mean') # "Mean" aggregation. self.lin = torch.nn.Linear(in_channels * 2, out_channels) def forward(self, x, edge_index): # Start propagating messages. return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x) def message(self, x_j): # x_j has shape [E, in_channels] return x_j def update(self, aggr_out, x): # aggr_out has shape [N, in_channels] combined = torch.cat([x, aggr_out], dim=-1) return self.lin(combined) ``` 这段代码实现了单层的 `SAGEConv` 卷积操作,其中包含了消息传递机制以及更新函数。这里采用均值作为聚合策略。 #### 构建完整的模型架构 ```python import torch.nn as nn class GraphSAGE(nn.Module): def __init__(input_dim, hidden_dims, output_dim): super(GraphSAGE, self).__init__() layers = [] dims = [input_dim] + list(hidden_dims) + [output_dim] for i in range(len(dims)-1): conv_layer = SAGEConv(dims[i], dims[i+1]) act_func = nn.ReLU() if i != len(dims)-2 else None block = nn.Sequential( conv_layer, act_func ) layers.append(block) model_layers = nn.ModuleList(layers[:-1]) # Exclude last activation function final_block = nn.Sequential(*layers[-1:]) model_final = nn.ModuleDict({'final': final_block}) setattr(model_layers, 'model', model_final['final']) setattr(GraphSAGE, '_modules', {'conv_blocks': model_layers}) def forward(x, edges): h = x.clone() for layer in getattr(GraphSAGE._modules['conv_blocks'], 'children')(): h = layer(h, edges) return h ``` 此部分定义了一个多层的 GraphSAGE 网络结构,每一层都由前面创建好的 `SAGEConv` 组件构成,并且加入了激活函数以增加表达能力。后一层不加非线性变换以便于后续任务处理,例如分类或回归等问题。 #### 训练过程概览 为了训练上述建立起来的 GraphSAGE 模型,在给定的数据集上执行以下步骤: - 初始化模型参数; - 对输入特征矩阵 X 和邻接表 E 进行预处理; - 使用负对数似然损失或其他适合的任务特定目标函数计算误差; - 应用随机梯度下降法调整权重直至收敛; 具体的训练细节可以根据应用场景的不同而有所改变,比如超参的选择、正则化技术的应用等[^4]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

这个男人是小帅

请小弟喝杯咖啡☕️鼓励下吧

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值