图神经网络入门篇Graph Neural Network

前言

在我上一篇博客,介绍基于random walk的节点表示方式,该方法的主要是思想是以one-hot的形式,经过Embedding层得到node vector,然后优化以下的似然函数来得到最优的Embedding Matrix

m a x ∑ u ∈ V l o g P ( N R ( u ) ∣ z u ) max \sum_{u \in V} logP(N_R(u)|z_u) maxuVlogP(NR(u)zu)

该方法有很多缺点

  • 参数没有共享,一个节点对应一个embedding值
  • 图通常需要用到节点特征,该方法没有办法结合节点特征

本文将会介绍基于GNN的表示方式,尽可能解决以上的问题。

一种简单的方法

图像有CNN,序列问题有RNN,但是对于图结构来说,这些模型都不适用,图的节点数量不固定,通常都会有很复杂的拓扑结构,无论是CNN还是RNN都没有办法处理这样动态的数据结构,那么如何解决这个问题呢?

最简单的方法就是采用图的邻接矩阵,并且把节点的特征拼接进来,再把拼接后的数据喂给一个神经网络。

在这里插入图片描述

该方法虽然可行,但是存在一些缺点

  • 当图过大时,邻接矩阵过大,对于显存要求过于苛刻
  • 训练好的模型不适用于不同大小的图,如果我们想用现有的模型表示新加的节点embedding,需要重新训练

基于该结构,我也写了另外一篇博客详细介绍,图卷积神经网络

基于聚合操作的图神经网络

除了上面提到的邻接矩阵的方式,实际上我们更需要一种方法能直接泛化训练过程没有出现过的顶点,首先我们定义几个符号

  • G 图
  • V 图中的节点
  • A 邻接矩阵
  • X 节点特征 X ∈ R m × ∣ V ∣ X \in \Reals^{m×|V|} XRm×V m表示特征数量

从之前的博文中,我们知道,一个节点的embedding是由其邻居节点决定的,那么我们是否可以使用神经网络来聚合其邻居节点的信息呢?答案是肯定的,我们给每一个节点根据其邻居定义一个计算图,如下图所示。

在这里插入图片描述

这样的结构让节点在每一层都有其embedding表示,其中第一层是每个节点的特征即 x x x,第k层是经过k跳后的节点embedding信息(层约深,获取到的全局信息就越多)图中灰色部分表示的是聚合操作,聚合操作有很多种方式

Average neighbor messages

我们先看一种常用的方式,取均值。假如我们要计算 v v v节点的embedding,首先始化第0层的embedding等于其节点特征

h v 0 = X v h^0_v=X_v hv0=Xv

然后计算下一层的emebdding,可以看到,计算分为了两部分,第一块是对 v v v节点的邻居节点的上一层的embedding取均值,然后与 v v v节点的上一层的embedding取加权平均数,这里的两个加权值就是我们需要训练的参数,最后外面加一层非线性变化,注意,这里的 σ \sigma σ是指非线性变化,不一定是sigmoid函数

h v k = σ ( W k ∑ u ∈ N ( v ) h u k − 1 ∣ N ( v ) ∣ + B k h v k − 1 ) h_v^k=\sigma(W_k \sum_{u \in N(v)}\frac{h_u^{k-1}}{|N(v)|} + B_k h_v^{k-1}) hvk=σ(WkuN(v)N(v)huk1+Bkhvk1)

最后节点v的embeddding就等于最后一层的embedding

z v = h v K z_v=h^K_v zv=hvK

GraphSAGE

上一节我们介绍了采用均值的方式来聚合邻居信息,这一节我们来看一些聚合方法的变体。GraphSAGE 的全名是Graph SAmple and aggreGatE,该模型分为了三个步骤

在这里插入图片描述

  • 对邻居节点采样
  • 聚合邻居信息
  • 得到节点embedding并对下游任务进行训练

可以发现,该模型相比于上一节介绍的内容添加了一个采样的过程,主要原因是为了提高效率,某些节点可能会有特别多的邻居,如果全部计算的话消耗的计算量会很大。设采样个数为k,如果节点邻居数少于k,则采用有放回的抽样方法,直到采样出k个顶点,如果邻居数大于k,则采用无放回的抽样。接下来我们详细看下算法步骤。

在这里插入图片描述
关键步骤是4、5、7,其中第四行对所有邻居节点节点进行聚合操作,第五行把聚合后的邻居节点embdding与上一层的节点的embdding进行拼接(上文是采用加权平均的方式),然后送入一层全连接层并添加一个激活函数,通常采用tanh。第七行会对每一层的计算结果做一个归一化的操作。

在graphSAGE中,aggregate有三种方式

  • Mean aggregato 和上文提到的一样,相对邻居取均值再拼接

在这里插入图片描述

  • LSTM aggregat 把邻居节点送入到lstm中得到最终的聚合值,lstm的优势在于能有更好的表示,但是lstm主要是用来处理顺序问题,不是对称的,也就是说输入顺序的不同对输出值会有影响

  • Pooling aggregator
    在这里插入图片描述

该方法是相对每个邻居节点做一个非线性变化,然后进行pool操作(对位取最大值或者均值)

如何训练

知道了怎么计算embedding,那下一步就是考虑怎么训练模型,怎么定义损失函数了,这里有两种方法

  • 非监督学习,和random walk那篇博客讲的方法一样,不再赘述
  • 监督学习,采用节点分类任务训练embedding

这里着重说下监督学习,我们对所有的节点进行一个分类任务,假如是二分类,那么损失函数就是一个交叉熵损失,其中 y v y_v yv是节点的真实标签, θ \theta θ是分类任务的训练参数。

L = ∑ v ∈ V y v l o g ( σ ( z v T θ ) ) + ( 1 − y v ) l o g ( 1 − σ ( z v T θ ) ) L = \sum_{v\in V}y_v log(\sigma(z_v^T \theta))+(1-y_v)log(1-\sigma(z_v^T \theta)) L=vVyvlog(σ(zvTθ))+(1yv)log(1σ(zvTθ))

训练的过程,我们可以把多个节点的embedding作为一个batch,如下图是3个节点对应的embedding

在这里插入图片描述

参数主要有三部分,分类任务的 θ \theta θ、生成embedding的 W K W_K WK B k B_k Bk,这些参数对于不同的节点都是共享的

在这里插入图片描述

训练好的模型,只要是同样的场景都可以使用,例如我们对某有机物A构建了其蛋白质结构图,同样适用于有机物B。在工业场景中,图中新加一个节点也是很常见的情况,特别是社交网络这样的图,我们依旧不需要重新训练模型,直接对新加的节点使用训练好的神经网络进行embedding的生成即可。

总结

本文介绍了图神经网络的一些基本模型,当然图神经网络还有更多其它的变体,例如聚合操作,实际上还有更加优秀的处理方式,例如采用attention机制,在下一篇博客中,我会为大家详细介绍。

References

cs224w 8. Graph Neural Networks
Inductive Representation Learning on Large Graphs

  • 4
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值