GCN是一种利用图结构和邻居顶点属性信息学习顶点Embedding表示的方法,GCN是直推式学习(只能在一个已知的图上进行学习),不能直接泛化到未知节点,当网络结构改变以及新节点的出现,直推式学习需要重新训练(复杂度高且可能会导致embedding会偏移),很难落地在需要快速生成未知节点embedding的机器学习系统上。
**GraphSAGE(Graph SAmple and aggreGatE)**是一种能利用顶点的属性信息高效产生未知顶点embedding的一种归纳式(inductive)学习的框架。
与GCN类似,其核心思想:学习一个映射 f ( . ) f(.) f(.),通过该映射图中的节点 v i v_i vi可以聚合它自己的特征 x i x_i xi与它的邻居特征 x j ( j ∈ N ( v i ) ) x_j \;(j \in N(v_i)) xj(j∈N(vi))来生成节点的新 v i v_i vi表示。 区别在于并未利用所有的邻居节点,聚合的方式也不同。GraphSAGE框架的核心是如何聚合节点邻居特征信息。
GraphSAGE 前向传播算法
下图是GraphSAGE的学习过程:
主要步骤如下:
(1)对邻居随机采样
(2)使用聚合函数将采样的邻居节点的Embeddin进行聚合,用于更新节点的embedding。
(3)根据更新后的embedding预测节点的标签。
更新过程:
(1)为了更新红色节点,首先在第一层(k=1)我们会将蓝色节点的信息聚合到红色节点上,将绿色节点的信息聚合到蓝色节点上。所有的节点都有了新的包含邻居节点的embedding。
(2)在第二层(k=2)红色节点的embedding被再次更新,不过这次用的是更新后的蓝色节点embedding,这样就保证了红色节点更新后的embedding包括蓝色和绿色节点的信息。这样,每个节点又有了新的embedding向量,且包含更多的信息。
算法细节如下:
需要注意以下几点:
1、 h v 0 h_v^0 hv0是每个节点的初始embedding特征向量
2、当 k = 1 k=1 k=1时,遍历所有的节点,求 h v 1 h_v^1 hv1,也就是算法的4-5行,也是最核心的部分。具体的:
(1)先对当前节点 v v v的邻居进行采样,得到邻居节点的集合 N ( v ) \mathcal N(v) N(v),对所有的邻居节点 { u ∈ N ( v ) } \{ u \in \mathcal N(v)\} {u∈N(v)}的 k − 1 k-1 k−1层的embedding: h u ( k − 1 ) = h u 0 h_u^{(k-1)}=h_u^{0} hu(k−1)=hu0 进行聚合,得到 v v v的邻居节点的代表向量 h N ( v ) k h_{\mathcal N(v)}^k hN(v)k。如何聚合后面会提到。
(2)concat操作,将的、邻居节点的代表向量 h N ( v ) k h_{\mathcal N(v)}^k hN(v)k 与自身的 h v k − 1 = h v 0 h_v^{k-1}=h_v^0 hvk−1=hv0 进行连接,然后与权重变量 W W W相乘,并进行激活。其中 W W W用于控制在模型的不同层或“搜索深度”之间传播信息。
这样求出的 h v 1 h_v^1 hv1就包含了邻居节点的信息。以此类推,当求 h v 2 h_v^2 hv2时会用到 h u 1 , u ∈ N ( v ) h_u^1,u \in \mathcal N(v) hu1,u∈N(v),而从上面的描述可知 h u 1 h_u^1 hu1已经包含了 u u u的邻居节点信息。所以在每次迭代或搜索深度时,节点从它们的本地邻居处聚集信息,随着这个过程的迭代,节点从图的更远处获得越来越多的信息。
3、随着K增大,节点可以聚合更多的信息,K既是聚合器的数量,也是权重矩阵的数量,还是网络的层数,这是因为每一层网络中聚合器和权重矩阵是共享的。网络的层数可以理解为需要最大访问到的邻居的跳数(hops),比如在figure 1中,红色节点的更新拿到了它一、二跳邻居的信息,那么网络层数就是2。
采样算法&聚合(aggragator)操作
采样算法
GraphSAGE采用了定长抽样的方法。先确定需要采样的邻居数 N N N,然后采用有放回的重采样/负采样的方法达到 N N N,这样做可以方便后期训练。
聚合(aggragator)操作
聚合方式有:平均、GCN归纳式、LSTM、pooling聚合器。(因为邻居没有顺序,聚合函数需要满足排序不变量的特性,即输入顺序不会影响函数结果)
1,平均聚合:对邻居节点的embedding中的每个维度取平均,然后与自身节点的embedding拼接后进行非线性变换。
h
N
(
v
)
k
=
mean
(
{
h
u
k
−
1
,
u
∈
N
(
v
)
}
)
h
v
k
=
σ
(
W
k
⋅
CONCAT
(
h
v
k
−
1
,
h
N
(
u
)
k
)
)
\begin{array}{c} h_{N(v)}^{k}=\operatorname{mean}\left(\left\{h_{u}^{k-1}, u \in N(v)\right\}\right) \\ h_{v}^{k}=\sigma\left(W^{k} \cdot \operatorname{CONCAT}\left(h_{v}^{k-1}, h_{N(u)}^{k}\right)\right) \end{array}
hN(v)k=mean({huk−1,u∈N(v)})hvk=σ(Wk⋅CONCAT(hvk−1,hN(u)k))
2,归纳式聚合:直接对目标节点和所有邻居emebdding中每个维度取平均,后再非线性转换。
h
v
k
=
σ
(
W
k
⋅
mean
(
{
h
v
k
−
1
}
∪
{
h
u
k
−
1
,
∀
u
∈
N
(
v
)
}
)
h_{v}^{k}=\sigma\left(W^{k} \cdot \operatorname{mean}\left(\left\{h_{v}^{k-1}\right\} \cup\left\{h_{u}^{k-1}, \forall u \in N(v)\right\}\right)\right.
hvk=σ(Wk⋅mean({hvk−1}∪{huk−1,∀u∈N(v)})
3,LSTM 聚合
LSTM函数不符合“排序不变量”的性质,需要先对邻居随机排序,然后将随机的邻居序列embedding作为LSTM输入。
4,Pooling聚合:先对每个邻居节点上一层embedding进行非线性转换,再按维度应用 max/mean pooling,捕获邻居集上在某方面的突出的/综合的表现 以此表示目标节点embedding。
h
N
(
v
)
k
=
max
(
{
σ
(
W
pool
h
u
i
k
+
b
)
}
,
∀
u
i
∈
N
(
v
)
)
h
v
k
=
σ
(
W
k
⋅
CONCAT
(
h
v
k
−
1
,
h
N
(
u
)
k
−
1
)
)
\begin{aligned} h_{N(v)}^{k} &=\max \left(\left\{\sigma\left(W_{\text {pool}} h_{u i}^{k}+b\right)\right\}, \forall u_{i} \in N(v)\right) \\ h_{v}^{k} &=\sigma\left(W^{k} \cdot \operatorname{CONCAT}\left(h_{v}^{k-1}, h_{N(u)}^{k-1}\right)\right) \end{aligned}
hN(v)khvk=max({σ(Wpoolhuik+b)},∀ui∈N(v))=σ(Wk⋅CONCAT(hvk−1,hN(u)k−1))
参数学习
GraphSAGE的参数主要是聚合器的参数和权重变量 W W W。为了获得最优参数就得定义合适的损失函数。
1、有监督学习
可以使用每个节点的预测label和真实label的交叉熵作为损失函数。
2、无监督学习
其中: z u z_u zu是节点 u u u通过GraphSAGE生成的embedding;
v v v是节点 u u u随机游走可到达的"邻居"节点。
v n ∼ p n ( v ) v_n \sim p_n(v) vn∼pn(v)表示 v n v_n vn是从节点u的负采样分 p n ( v ) p_n(v) pn(v)的采样。负采样指我们还需要一批不是 u u u邻居的节点作为负样本。
Q为采样样本数。
embedding之间相似度通过向量点积计算得到。
如何理解这个损失函数?
先看损失函数的蓝色部分,当节点 u、v 比较接近时,那么其 embedding 向量 z u , z v z_u, z_v zu,zv的距离应该比较近,因此二者的内积应该很大,经过σ函数后是接近1的数,因此取对数后的数值接近于0。
再看看紫色的部分,当节点 u、v 比较远时,那么其 embedding 向量 z u , z v z_u, z_v zu,zv的距离应该比较远,在理想情况下,二者的内积应该是很大的负数,乘上-1后再经过σ函数可以得到接近1的数,因此取对数后的数值接近于0。
基于tensorflow2.0实现Graph SAGE
主要实现图的无监督学习与分类。
参考文章