《Transformer-XL_Attentive Language Models Beyond a Fixed-Length Context》论文笔记

一、摘要

  1. 传统的Transformers受限于固定长度的文本。
  2. 本文提出了Transformer-XL模型这个模型使得文本的依赖能够超越固定文本的长度,并且不会产生时间上的错乱
  3. 模型由片段级别递归和新型的位置编码方案组成,主要解决了文本长距离依赖和文本碎片化问题,在时间上面也比vanilla Transformer快很多。
  4. Transformer-XL模型在enwiki8数据上取得0.99的困惑度,text8上取得1.08困惑度,WikiText-103上取得18.3的困惑度,One Billion Word取得21.8的困惑度,PennTreebank上取得54.5的困惑度。

二、介绍

  1. 对于序列数据,其中一个很大的挑战就是学习长距离依赖。
  2. RNN模型可以建立长距离依赖关系,但是无法处理梯度消失梯度爆炸的问题。虽然后来出现了LSTM模型,在一定程度上改善了这一问题,但也没有改善地很完善,普遍的研究表明LSTM模型平均学习文本依赖的长度是200.
  3. 尽管Al-Rfou在2018年提出的vanilla Transformer模型在固定文本长度的训练中取得了成功,但是这个模型并没有建立文本块之间的信息流动。因此,模型也无法捕捉超越文本长度的依赖。此外,固定文本长度的实现也无法根据语义来确定segment的边界
  4. 因为缺少了segment之间的信息流动,当模型去预测segment中的前几个单词的时候就可能会取得比较差的效果,这种现象称为文本碎片化
  5. 在本模型中,在计算segment的时候,会重用前一个segment的隐藏状态。前一个segment的隐藏状态会被缓存在内存中,用来建立segment之间的信息流动。
  6. 上一个segment的信息传递到下一个segment中,这同时也可以解决文本碎片化问题
  7. 为了能够重用之前segment的隐藏状态,模型采用相对位置编码
  8. Transformer-XL模型是第一个中自注意力模型,在字符级别和单词级别都超过RNNs的模型。

三、vanilla模型

  1. 给定语料库中的句子x=(x1,…,xT),要计算句子出现的概率P(x),就会使用下面这个公式:

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-piiUcqKh-1636525500334)(D:\typora_image\image-20211028205824759.png)]

  2. 在transformer语言模型中,中心问题就是如何将任意长度的文本转变为固定长度的特征表示。

  3. 其中一个方法就是将文本切成很多个小的segment,忽略segment之间的信息流动,例如vanilla模型。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-hKrISi0u-1636525522419)(D:\typora_image\image-20211028210851816.png)]

  1. 但是这种方法主要存在两个方面的限制:1、最长的依赖是segment的长度,字符级别的模型一般是几百个字符。虽然注意力模型在处理长距离依赖上面有独特的优势,但是vanilla模型并没有充分发挥这个优势。 2、无法根据语义来切分片段,仅仅是根据segment的最大长度来将语料库切分成segment,这就会导致文本碎片化问题。 3、在vanilla模型中,进行预测时每移动一个位置,都需要重头开始计算,保证获得长距离的依赖。

四、transformer-xl之片段级循环机制

  1. 在进行训练的时候,前一个segment的隐藏状态会被保存下来,将前一个segment的信息传递给当前segment的。但是,梯度只在一个当前segment中进行传播

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-my8B03CP-1636525500338)(D:\typora_image\image-20211028222811006.png)]

  2. 现有两个segment,sτ=[xτ,1,…,xτ,L]和sτ+1=[xτ+1,1,…,xτ+1,L],sτ+1的第n个隐藏状态产生的公式如下:

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Cr7aY01E-1636525500340)(D:\typora_image\image-20211028212812376.png)]

    • SG(*)代表停止梯度计算
    • [hu○hv]代表两个隐藏序列在长度方向上进行concat
    • W代表模型的参数
    • 上面的公式和标准的transformer模型相比,最大的区别在于kτ+1n和vτ+1n直接受限于image-20211028213236570,间接受限于hτn-1
  3. 本模型和BPTT的区别在于,本模型会将前一个segment的整个隐藏状态进行缓存,而不是仅仅缓存最后一个隐藏状态。并且本模型使用的是相对位置编码

  4. 相比于vanilla模型,不需要每次重头开始计算,使用了隐藏状态重用机制,节省了计算消耗。

  5. 在理论上,如果内存足够,可以存之前多个segment的隐藏状态,这样可以获得更长距离的依赖。

五、transformer-xl之相对位置编码

  1. 但是使用了之前segment的隐藏状态,要如何区分呢?

  2. 如果将原本的绝对位置编码加入到片段级循环机制中。

    • 绝对位置编码U∈RLmax*d,其中第i行Ui对应于一个片段中的第i个绝对位置。

    • 如何直接将绝对编码直接加入,会得到如下公式:(E属于RL*d单词级别嵌入,f代表转换函数)

      [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-2U4HCWix-1636525500342)(D:\typora_image\image-20211028214652342.png)]

    • 可以注意到,上面这个公式,E和Esτ+1和同样的绝对位置编码U1:L相加。这样就会导致模型无法区分xτ,j和xτ+1,j

  3. 为了避免上面这种情况,本模型加入了相对位置编码以区分不同的segment。对于query向量qτ,j之需要和keyτ,<=j进行关联,不需要知道每个key向量的绝对位置。

  4. 基于以上几点,我们讨论了加入相对位置编码。

    • 创建相对位置编码R∈RLmax*d,第i行表示和第i个token之间的相对位置距离。这样模型就能很好地区分出xτ,j和xτ+1,j

    • 并且理论上我们也不会丢失绝对位置信息,绝对位置信息可以从相对位置信息中递归恢复。

    • 原始的绝对编码时计算分数的公式如下:

      [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-cKhFFRNV-1636525500344)(D:\typora_image\image-20211028215758368.png)]

    • 对上式进行展开,可以得到下面的公式:

      [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-9LBJ2hma-1636525500346)(D:\typora_image\image-20211028215822668.png)]

    • 加入相对位置编码,改变得分计算公式,我们会进一步得到下面这个公式:

      [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-TD9cQX2L-1636525500347)(D:\typora_image\image-20211028215902350.png)]

    • 对于上面这个公式,本文作者是主要做了下面三方面的改动。

      1. 第一个变化:在(b)和(d)中将绝对位置编码Uj替换为相对位置编码Ri-j,其中R是transformer中采用的不需要学习的sinsoid编码矩阵。
      2. 第二个变化:在(c)和(d)中引入了可学习的向量u∈Rd和v∈Rd来替换transformer中的query向量UiTWqT。这个改变主要是为了,因为查询向量在所有查询位置都是相同的,所以不管query位置如何,对不同单词的注意偏向应该保持一致。
      3. 第三个变化:在(a)、(b)、(c)、(d)中,Wk被拆分成Wk,E和Wk,R,也就是说输入序列和位置编码不再共享权重。
    • 新变化的相对距离公式,每一个小项又有单独的名字:

      1. (a):没有考虑位置编码的原始分数,只是基于内容的寻址
      2. (b):相对于当前内容的位置偏差
      3. (c):从内容层面衡量键的重要性,表示全局的内容偏置
      4. (d):从相对位置层面衡量键的重要性,表示全局的位置偏置
  5. 综合上面的所有步骤,总结整个模型的公式如下:(其中hτ0=E定义为简单的embedding)

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-6S7iJWi4-1636525500349)(D:\typora_image\image-20211028221628292.png)]

六、实验

  1. WikiText-103数据集用来测试模型是否能够学习长距离依赖。

  2. enwik8数据集合WikiText-103数据集差不多,可以看到12层的transformer-xl和64层的vanilla transformer效果差不多

  3. One Billion Word数据集适合测试模型能否学习短距离依赖。

  4. 字符级别数据集Penny Treebank

七、消融实验

  1. 消融实验分为两部分:循环机制和新型编码方案

  2. 作者提出了一个“Full” and “half” losses,指的是将交叉熵损失应用到段中的全部或最近的一半位置

  3. 绝对编码只有在一半损耗的情况下才能很好地工作,因为一半损耗排除了训练中注意力长度很短的位置,以便更好地泛化(这里没明白)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值