Openreview:
三位评委得分分别为6, 8, 10,单项满分
- 8分:方法独特,对语言模型有很大的改进,取得sofa。但有些实验做的不够,比如选择的token数量;
- 10分:非常好,ppl显著降低,利用外部知识促进语言模型,方法独特;
- 6分:只比较WikiText-103数据,其他两个数据集没比较,实验很多细节没有公布,复现性差。
先看个例子:
GNN构造:
首先对训练集所有token的隐层向量进行缓存
-
节点:2种类型节点,当前上下文的token和从外部检索得到的token
-
边:2种类型边,a0: 当前上下文的token -> 外部token,an:相同类型token之间的连接,内部token -> 内部token,外部token -> 外部token
-
节点初始化:
使用当前输入文本的向量ht,采用Faiss方法去训练集中检索最近的k个token的向量{w1, …wk}。然后取每个token的前l到后r个token的向量集合,表示当前token的节点的初始化向量表示。其中r=l=1。
这样就可以初始化所有的节点了。 -
模型层数:3
GNN节点更新:
-
Attention:
源节点和目标节点使用了不同的参数w
-
Feature:
-
Aggregate:
将attention权重点乘Feature,然后再做个W转换。
结合KNN得到最大似然概率:
这个最大似然概率公式结合了KNN和GNN两个模型,一方面需要极大的提高KNN检索相似句子的准确率,另一方面需要提高GNN-LM下一个token的概率。
实验细节:
- 数据:WikiText-103 (Merity et al., 2016), One Billion Word (Chelba et al., 2013) and Enwik8 (Mahoney, 2011)
- 检索:每个token检索top-1024 token用于KNN,top-128用于GNN。(注意:检索到的token均取了前后各1个token,加自身也就是3个token的文本)
- token泄露:为了防止预测的token在检索的节点里面,论文将第t个token以前的节点过滤掉了。如果是xl,那么memory里面的token节点也过滤掉了。但是在推理过程中没有过滤。
PS:预训练语言模型的参数是固定的
结果:
1、速度、内存、ppl:
内存:k的影响很大,但均明显增大
速度:相比base LM,慢了8-20倍
ppl:k的影响很大,但均优于base LM,从曲线来看,继续增大k的数量,ppl还会继续降低
- 内存过大解决方案:
1、先用k=32训练一个小模型,然后再基于该模型finetune一个k=128的大模型。
2、因为预训练语言模型中已经包含了较长的上下文信息,因此GNN的输入长度可以不用3072,使用128代替。
2、KNN对预测token的召回率影响:
每个范围50k个token,从表可以看到,如果knn找回的token越大,性能越好。