GraphFormer笔记

img

NIPS2021

1 Introduction

文本属性图是一种广泛使用的数据格式。在文本属性图上的表示学习,就是根据节点文本以及邻居节点的文本信息来获得低维的 embedding。现在的做法有一种是讲预训练模型和 GNN 技术结合,即 PLM 去独立获取文本的 embedding,再使用 GNN 去聚合节点的信息。

img

这种做法叫做 Transformer-GNN 级联做法,如 图 1 所示,也就是 Transformer 的组件是部署在 GNN 的组件的前面的。实际上,因为两者独立编码,缺乏了节点文本信息之间的交换。在一些语义上具备歧义的场景会产生没有完全理解全文的现象。例如,一个节点内文本”notes on Transformer“,一个与他连接的节点内文本”tutorials on machine translation“,所以此处的 Transformer 就不太可能是指变压器。

img

作者提出了新的架构,如图 2 所示,将 GNN 和 Transformer 嵌套,叫做 GraphFormer。文本编码和图信息聚合被融合为一个迭代工作流程。在每一次迭代过程中,相互连接的节点就会交换信息。下一层的 Transformer 就会在增强的节点特征上继续编码,和级联的架构相比,这种做法产生了节点之间的信息交换和增强,得到的节点 embedding 更具备表达力。

1.1 关于GraphFormer的训练

单个节点的信息饱满程度足够应对大多数都场景,即使不去聚合邻居节点的信息,也可以解决问题。如此以来,最终训练得到的 GNN 编码器会是缺乏”锻炼的“。受到课程学习(curriculum learning)的启发,作者在 GraphFormer 的训练上分为两个过程:

curriculum learning:模型的学习呈阶段性,从简单样本上学习,到从复杂样本上学习。

  1. 将原始数据的一部分经过人工污染,让 GNN 很难仅仅通过中心节点就做出准确预测,于是就会去学习利用周围的信息。
  2. 在未经人工污染的数据上,去拟和目标分布。

1.2 单向图注意(undirected graph attention)

因为所有连接的节点都是相互依赖的,所以如果有新的中心节点需要操作,无论他周围的节点有没有被处理过,都要从头再来。这样就产生了很多不需要的计算。本文采用了单向图注意来缓解这个问题。只有中心节点需要参考周围节点,而周围节点直接独立编码,于是已经编码过的邻居节点的结果可以重用。

2 GraphFormers

在文本图数据的表示中,节点 x x x 是一个文本, N x N_x Nx 表示邻居, G x G_x Gx 表示整个图,目标是通过 embedding 的相似度预测 x q x_q xq x k x_k xk 是否存在链接。

2.1 GNN-nested Transformers

首先,中心节点和邻居节点被编码为一个 token 序列,前端带有 [cls],[cls] 表示的是一个节点的 embedding。基于词 embedding 和位置 embedding,输入序列被映射到初始 embedding 序列 { H g 0 } G \{\mathrm{H}_g^0\}_G {Hg0}G 。embedding 序列由多个 GNN-nested Transformer 层编码而得。

img

2.1.1 Graph Aggregation in GNN

如图第 l l l 层为例,首先用 Transformer 编码出初始的 token 级的编码序列 { H g 0 } \{\mathrm{H}_g^0\} {Hg0} ,每一个节点对应一个 { H g i 0 } \{\mathrm{H}_{g_i}^{0}\} {Hgi0} ,最开始,会直接将 token 级的 embedding 赋值给 node 级的 embedding, z ^ g l ← H g l \mathbf{\hat{z}}_g^l\leftarrow \mathbf{H}_g^l z^glHgl ,node 级的 embedding 是由 GNN 聚合所有节点信息之后得到的。所有的 z ^ g l \mathbf{\hat{z}}_g^l z^gl 组成了 Z G l \boldsymbol{Z}_{G}^{l} ZGl ,作为多头注意力的输入,接下来和 GAT 的操作类似,得到 GAT 的编码:
Z ^ G l = M H A ( Z G l ) ; \hat{\mathbf{Z}}_G^l=\mathrm{MHA}(\mathbf{Z}_G^l); Z^Gl=MHA(ZGl);

M H A ( Z G l ) = C o n c a t ( h e a d 1 , . . . , h e a d h ) ; \mathrm{MHA(}\mathbf{Z}_{G}^{l})=\mathrm{Concat(}\mathbf{head}_1,...,\mathbf{head}_h); MHA(ZGl)=Concat(head1,...,headh);

h e a d j = s o f t m a x ( Q K T d + B ) V ; \mathbf{head}_{j}=\mathrm{softmax}(\frac{\mathbf{QK}^{\mathrm{T}}}{\sqrt{d}}+\mathbf{B})\mathbf{V}; headj=softmax(d QKT+B)V;

Q = Z G l W j Q ; K = Z G l W j K ; V = Z G l W j V ; \mathbf{Q}=\mathbf{Z}_G^l\mathbf{W}_j^Q; \mathbf{K}=\mathbf{Z}_G^l\mathbf{W}_j^K; \mathbf{V}=\mathbf{Z}_G^l\mathbf{W}_j^V; Q=ZGlWjQ;K=ZGlWjK;V=ZGlWjV;

公式 4 中有三个投影矩阵,对应的是第 j j j 个注意力头。聚合信息之后的 Z ^ G l \hat{\mathbf{Z}}_G^l Z^Gl 里的每一个 z ^ g l \mathbf{\hat{z}}_g^l z^gl 都会被分配到原来的节点上,并和自己 token 级的 embedding 做一个 concat 操作。
H ^ g l ← C o n c a t ( z ^ g l , H g l ) . \widehat{\mathbf{H}}_g^l\leftarrow\mathrm{Concat}(\mathbf{\hat{z}}_g^l,\mathbf{H}_g^l). H glConcat(z^gl,Hgl).

2.1.2 Text Encoding in Transformer

img

经过图聚合增强之后的 token 级的 embedding H ^ g l \widehat{\mathbf{H}}_g^l H gl 接下来由 Transformer 组件处理,
H ^ g l = L N ( H g l + M H A a s y ( H ^ g l ) ) ; \hat{\mathbf{H}}_g^l=\mathrm{LN}(\mathbf{H}_g^l+\mathrm{MHA}^{asy}(\widehat{\mathbf{H}}_g^l)); H^gl=LN(Hgl+MHAasy(H gl));

H g l + 1 = L N ( H ^ g l + M L P ( H ^ g l ) ) . \mathbf{H}_g^{l+1}=\mathrm{LN}(\widehat{\mathbf{H}}_g^l+\mathrm{MLP}(\widehat{\mathbf{H}}_g^l)). Hgl+1=LN(H gl+MLP(H gl)).

  1. 首先, H ^ g l \widehat{\mathbf{H}}_g^l H gl 经过一个非对称多头注意力层,并和原始的 token 编码作 add 并进行 Norm 操作
  2. 经过一个 MLP 投影之后再加上没投影之前的,和未投影的 token 编码作 add 并进行 Norm 操作

从而得到下一层的 token 级的输入 H g l + 1 \mathbf{H}_g^{l+1} Hgl+1 ,最后一层的输出 z x L ( i . e . , H g L [ 0 ] ) \mathbf{z}_{x}^{L} (\mathrm{i.e.},{\mathbf{H}}_{g}^{L}[0]) zxL(i.e.,HgL[0]) 就是最终的节点表示。总的时间复杂度是 O ( M ^ 2 + M P 2 ) O(\hat{M}^2+MP^2) O(M^2+MP2) ,分别是 GAT 和 Transformer 需要的时间,其中 M M M 是节点数, P P P 是token 的数量

2.2 Undirected Graph Aggregation

之前提到了关于模型的一个问题:在编码过程中,输入节点相互依赖。因此,为了为节点生成嵌入,其邻域中的所有相关节点都需要从头开始编码,而不管它们是否之前已经处理过。

作者采用了一种叫做单向图聚合的做法简化了计算:只有中心节点需要考虑周围节点的信息,而作为邻居节点,则直接由文本编码赋值即可。
H g l + 1 = { T R M l ( H x l ^ ) , g = x ; T R M l ( H g l ) , ∀ g ∈ N x . \mathbf{H}_g^{l+1}=\begin{cases}\mathrm{TRM}^l(\widehat{\mathbf{H}_x^l}), g=x;\\\mathrm{TRM}^l(\mathbf{H}_g^l), \forall g\in N_x.\end{cases} Hgl+1={TRMl(Hxl ),g=x;TRMl(Hgl),gNx.
在框架中产生的节点的 node 级编码 z ^ g l \mathbf{\hat{z}}_g^l z^gl 会被缓存,作为后续可能计算使用。

2.3 两阶段模型训练

2.3.1 训练目标

以链接预测作为训练任务,给定一个节点对 q q q k k k​ ,预测基于他们各自的 embedding,是否存在链接。最小化以下分类损失:
L = − log ⁡ exp ⁡ ( ⟨ h q , h k ⟩ ) exp ⁡ ( ⟨ h q , h k ⟩ ) + ∑ r ∈ R exp ⁡ ( ⟨ h q , h r ⟩ ) . \mathcal{L}=-\log\frac{\exp(\langle\mathbf{h}_q,\mathbf{h}_k\rangle)}{\exp(\langle\mathbf{h}_q,\mathbf{h}_k\rangle)+\sum_{r\in R}\exp(\langle\mathbf{h}_q,\mathbf{h}_r\rangle)}. L=logexp(⟨hq,hk⟩)+rRexp(⟨hq,hr⟩)exp(⟨hq,hk⟩).
h q \mathbf{h}_q hq h k \mathbf{h}_k hk 是模型得到的节点的 embedding, < ⋅ > <\cdot > <> 表示内积操作, R R R​ 表示负采样样本。

深入学习:作者在实验中采用了 in-batch negative sampling 的做法,相比其他训练方法,这种做法的优点是什么?

in-batch negative sampling:多个训练 batch,一个 batch 内的正样本将作为其他 batch 内的负样本存在

2.3.2 两阶段训练

在 GraphFomer 里,实际上中心节点和周围节点是被区别对待的,这样做会破坏模型店训练效果。具体来讲,就是中心节点的信息可以直接使用,但是周围节点的信息却需要 3 个步骤才能引入

  1. 先编码为 node 级的 embedding
  2. 经过 GAT 操作和中心节点进行聚合
  3. 引入中心节点经过图增强之后的 token 级 embedding

那么,如果中心节点的文本信息量已经足够完成预测,根本不需要考虑周围的节点信息的话。如此以来,模型的 Transformer 就会很强大,但是 GNN 模块却很弱。为了减轻这个问题,作者提出了一个”热身任务“:链接预测的训练是建立在被人工污染的数据上的。

polluted nodes:一个节点的一个 token 子集会被随机遮盖(masked)

因为被遮盖导致节点表示不足够产生精确的预测,所以模型会强迫自己大量利用周围的信息(加强周围信息的权重)学习。

第一阶段损失

L ′ = − log ⁡ exp ⁡ ( ⟨ h q ~ , h k ~ ⟩ ) exp ⁡ ( ⟨ h q ~ , h k ~ ⟩ ) + ∑ r ∈ R exp ⁡ ( ⟨ h q ~ , h r ~ ⟩ ) \mathcal{L}'=-\log\frac{\exp(\langle\mathbf{h}_{\tilde{q}},\mathbf{h}_{\tilde{k}}\rangle)}{\exp(\langle\mathbf{h}_{\tilde{q}},\mathbf{h}_{\tilde{k}}\rangle)+\sum_{r\in R}\exp(\langle\mathbf{h}_{\tilde{q}},\mathbf{h}_{\tilde{r}}\rangle)} L=logexp(⟨hq~,hk~⟩)+rRexp(⟨hq~,hr~⟩)exp(⟨hq~,hk~⟩)

波浪号标志的元素表示从被污染的节点上得到的 embedding。最小化该损失函数直到收敛。

第二阶段损失

L = − log ⁡ exp ⁡ ( ⟨ h q , h k ⟩ ) exp ⁡ ( ⟨ h q , h k ⟩ ) + ∑ r ∈ R exp ⁡ ( ⟨ h q , h r ⟩ ) . \mathcal{L}=-\log\frac{\exp(\langle\mathbf{h}_q,\mathbf{h}_k\rangle)}{\exp(\langle\mathbf{h}_q,\mathbf{h}_k\rangle)+\sum_{r\in R}\exp(\langle\mathbf{h}_q,\mathbf{h}_r\rangle)}. L=logexp(⟨hq,hk⟩)+rRexp(⟨hq,hr⟩)exp(⟨hq,hk⟩).

也就是正常的损失函数。

3 实验

3.1 数据集

  • DPLB:论文引文图,论文的标题作为文本特征
  • WIKI:维基百科的实体图,每个实体介绍中的第一个句子作为文本特征
  • Product Graph:在线产品数据集,产品名称,品牌作为文本特征

输入文本的 tokenization 是采用的子词粒度分词的 WordPiece 法。在实验中,每一个文本和 5 个均匀采样的邻居相关联,数据集规格如下:

img

在实验中,每一个测试实例中,一个 q q q 节点会被提供 300 个 k k k 节点:1 个正例和 299 个负例。 关于链接预测的衡量指标,有 3 个

  1. Precision@1:预测正确的相关结果占全部结果的比例

Precision@k = TP@k TP@k + FP@k \text{Precision@k}=\frac{\text{TP@k}}{\text{TP@k}+\text{FP@k}} Precision@k=TP@k+FP@kTP@k

  1. **NDCG(Normalized Discounted Cumulative Gain)**归一化折损累积增益:考虑了返回顺序的评价指标

N D C G @ k = D C G @ k I D C G @ k \mathrm{NDCG}@k=\frac{\mathrm{DCG}@k}{\mathrm{IDCG}@k} NDCG@k=IDCG@kDCG@k

D C G @ k = ∑ i = 1 k r e l i log ⁡ 2 ( i + 1 ) \mathrm{DCG@k}=\sum_{\mathrm{i=1}}^\mathrm{k}\frac{\mathrm{rel_i}}{\log_2\left(\mathrm{i+1}\right)} DCG@k=i=1klog2(i+1)reli

I D C G @ k = ∑ i = 1 ∣ R E L ∣ r e l i log ⁡ 2 ( i + 1 ) \mathrm{IDCG@k}=\sum_{\mathrm{i=1}}^{|\mathrm{REL}|}\frac{\mathrm{rel_i}}{\log_2(\mathrm{i+1})} IDCG@k=i=1RELlog2(i+1)reli

其中 r e l i rel_i reli 表示第 i i i 个结果的真实相关性分数。第三个公式是理想的 DCG,将 rel 集合按照从大到小的顺序排序,取前 k 个。

  1. **MRR(Mean Reciprocal Rank)平均倒数排名:**强调用户的需求项在模型推荐列表中的位置

M R R = 1 S ∑ i = 1 S 1 p i \mathrm{MRR}=\frac1S\sum_{i=1}^S\frac1{p_i} MRR=S1i=1Spi1

S S S 表示全部样本数量, p i p_i pi 表示第 i i i 个样本在模型推荐列表中的位置。

3.2 实验结果

img

3.3 消融实验

img

3.4 效率分析

每一个 mini-batch32 个节点,一张 P100,节点的 token 长度是 16

img

4 总结

  1. 将 GNN 和 Transformer 嵌套在一起,因此文本底层语义被挖掘的同时还能聚合周围邻居的信息
  2. 两阶段训练方式加强了训练质量
  3. 单向图聚合缓解了计算负担
  • 4
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值