Transformer-XL语言模型:超长上下文依赖

在这里插入图片描述
论文链接:https://arxiv.org/pdf/1901.02860.pdf
代码链接:https://github.com/kimiyoung/transformer-xl
参考来源:https://mp.weixin.qq.com/s/C1hXU3HMSXSY5Ru9r1CZAA

导读

今天学习的是谷歌大脑的同学和 CMU 的同学于 2019 年联合出品的论文《Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context》,目前被引次数超 200 次。

这篇论文提出的 Transformer-XL 主要是针对 Transformer 在解决长依赖问题中受到固定长度上下文的限制,如 Bert 采用的 Transformer 最大上下文为 512。

Transformer-XL 采用了一种 segment-level 的递归方法,不仅解决长依赖的问题,还解决了上下文碎片问题。最终,Transformer-XL 能学习到的长依赖超过 LSTM 80%,并比原来的 Transforner 多出 4.5 倍。而且 Transformer-XL 在长短序列中都获得了不错的性能,预测速度更是比原来快了 1800 多倍。

1、摘要

Transformer具有学习长依赖关系的潜力,但是受到语言建模中上下文长度固定的限制。为此,本文提出一种新的神经网络架构Transformer-XL,该网络结构能够在不破坏时间一致性的情况下,学习到超越固定长度的依赖性。该网络结构由片段级的循环机制(segment-level recurrence)和全新的位置编码策略(positional encoding scheme)组成。其优点是不仅可以捕获更长的依赖关系,还可以解决上下文碎片化(context fragmentation)的问题。从实验结果上来看,Transformer-XL 学习到的依赖性比 RNN 学习到的长 80%,比标准 Transformer 学到的长 450%,无论在长序列还是短序列中都得到了更好的结果,而且在评估时比标准 Transformer 快 1800+ 倍。值得一提的是,Transformer-XL还刷新了 bpc 和perplexity(困惑度)的当前最佳结果:在 enwiki8 上 bpc 从 1.06 提升至 0.99,在 text8 上从 1.13 提升至 1.08;在 WikiText-103 上困惑度从 20.5 提升到 18.3,在 One Billion Word 上从 23.7 提升到 21.8,在宾州树库(不经过微调的情况下)上从 55.3 提升到 54.5。本文模型的代码、预训练模型以及超参数在 TensorFlow 和 PyTorch 中都可以使用。

2、引言

语言建模需要对长期依赖性进行建模,它成功应用了无监督的预训练方法 (Peters et al., 2018; Devlin et al., 2018)。但要让神经网络对序列数据的长期依赖性建模一直都是一项挑战。RNN网络,特别是LSTM是一个标准的方案,它可以在多个benchmarks上获得健壮的结果(strong results)。尽管其使用广泛,但是RNNs由于梯度消失和梯度爆炸问题的存在,难以优化。纵使引入一些门限和梯度裁剪技术,仍然不足以完全解决该问题。此前的工作已经表明LSTM平均可以捕获200个word的上下文信息,这也指出了进一步改进的空间。

另一方面,通过attention机制直接连接长的word pairs可以缓解优化问题,同时习得长依赖(即原始的Transformer工作)。近来Al-Rfou 等人(2018)设计了一组辅助损失来训练深度 Transformer 网络进行字符级(character-level)语言建模,其结果远超LSTM。虽然已经取得成功,但是 Al-Rfou 等人(2018)的语言模型是在长度固定的几百个字符片段上独立训练的,没有任何跨片段的信息流(即多个segments,每个segment的长度固定,由数百个characters组成。但是segments之间没有信息交流)。由于上下文的长度是固定的,因此模型无法捕获任何超过预定义上下文长度的长依赖。此外,长度固定的segments都是在不考虑句子或其它语义边界的情况下,通过选择连续的符号块来创建的。因此,模型缺乏必要的上下文信息来很好地预测前几个符号,这就导致模型的优化效率和性能低下。我们将这个问题称为上下文碎片化(context fragmentation)。

为了解决上文提到的上下文固定长度的限制,本文提出了一种叫做Transformer-XL(超长)的新架构。我们将循环(recurrence)概念引入了深度自注意力网络。我们不再从头计算每个新segment的隐藏状态,而是复用从之前segments中获得的隐藏状态。被复用的隐藏状态视为当前segment的memory,而当前的segment为segments之间建立了循环连接(recurrent connection)。因此,超长依赖性建模成为了可能,因为信息可以通过循环连接来传播。同时,从之前的segment传递信息也可以解决上下文碎片化的问题。更重要的是,本文展示了使用相对位置而不是用绝对位置进行编码的必要性,这样做可以在不造成时间混乱(temporal confusion)的情况下,实现状态的复用。因此,作为额外的技术贡献,文本引入了简单但有效的相对位置编码公式,它可以泛化至比在训练过程中观察到的长度更长的注意力长度。

从单词级(word-level)到字符级(character level)的五个语言建模数据集上,Transformer-XL都获得了很好的结果。Transformer-XL在仅基于100M tokens训练的基础上也可以生成相对连贯的长文本文章。

本文的主要贡献包括:
(1)在纯粹的自注意力模型中引入了recurrence的概念,即循环连接。
(2)推导了一种新的位置编码方案。

这两种技术构成了一组完整的解决方案,因为其中任何一种单独都不能解决上下文长度固定的问题。Transformer-XL是首个从实质上不管是character-level还是word-level都比RNN更优秀的自注意力模型。

3、Transformer-XL模型

3.1、Vanilla Transformer

要想将 Transformer 应用到模型中,要解决的核心问题是如何训练 Transformer 使其可以将任意大小的上下文编码为固定大小的 Representation。

如果不考虑计算资源和内存的话,最简单粗暴的方法就是直接使用 Transformer 来对整个序列进行编码。但我们知道这种方法是不可能的。

还有一种可行但是比较粗糙的方法是将整个语料库分为多个大小相同的片段(segment),然后只在每个片段上训练而忽视所有的上下文信息,这种方法我们称为 Vanilla Transformer:
在这里插入图片描述
在预测过程中,Vanilla Transformer 也采用与训练相同大小的片段来预测最后一个位置,然后每次基于滑动窗口向右移动一个位置:
在这里插入图片描述
这种方法一定程度上确保了在预测过程中尽可能大的利用上下文,缓解了上下文碎片问题,但由于每次移动,新的片段都需要重新计算一次,所以其计算代价昂贵。

3.2、Segment-Level Recurrence

为了解决固定长度上下文的带来的问题,作者建议在 Transformer 架构中引入递归机制(Recurrence Mechanism)。在训练过程中,前一段计算出来的隐藏层状态会被被固定并缓存下来,当模型处理下一个新段时作为扩展上下文而被重用:
在这里插入图片描述
这种附加的连接可以随着网络深度的增加而增大依赖项的最大长度(想不通的可以想一下 GCN 的一阶领域)。除此之外,这种递归机制还可以解决上下文碎片问题,为新段前端的令牌提供必要的上下文信息。

我们来给出具体计算过程的数学公式:

假设现在有两个连续的分割片段 s τ = [ x τ , 1 , ⋯   , x τ , L ] s_{\tau}=[x_{\tau,1},\cdots,x_{\tau,L}] sτ=[xτ,1,,xτ,L] s τ + 1 = [ x τ + 1 , 1 , ⋯   , x τ + 1 , L ] s_{\tau+1}=[x_{\tau+1,1},\cdots,x_{\tau+1,L}] sτ+1=[xτ+1,1,,xτ+1,L] ,其中 x x x 表示 token, L L L为序列长度, s τ s_{\tau} sτ表示第 τ \tau τ 个分割片段。

假设 Transformer 有 N N N 层,那么每个片段 s τ s_{\tau} sτ 就有 N N N 个隐藏层状态,我们将第 τ \tau τ 个片段的第 n n n 个隐藏层状态表示为 h τ n h_{\tau}^n hτn, 那么第 τ + 1 \tau+1 τ+1 个片段的第 n n n 层隐藏层状态就可以通过下式得出: h ~ τ + 1 n − 1 = [ S G ( h τ n − 1 ) ∘ h τ + 1 n − 1 ] \tilde{h}_{\tau+1}^{n-1}=[SG(h_{\tau}^{n-1})\circ h_{\tau+1}^{n-1}] h~τ+1n1=[SG(hτn1)hτ+1n1] q τ + 1 n , k τ + 1 n , v τ + 1 n = h τ + 1 n − 1 W q T , h ~ τ + 1 n − 1 W k T , h ~ τ + 1 n − 1 W v T q_{\tau+1}^n,k_{\tau+1}^n,v_{\tau+1}^n=h_{\tau+1}^{n-1}W_q^T,\tilde{h}_{\tau+1}^{n-1}W_k^T,\tilde{h}_{\tau+1}^{n-1}W_v^T qτ+1n,kτ+1n,vτ+1n=hτ+1n1WqT,h~τ+1n1WkT,

  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值