快速Transformer解码:Multi-query Attention

222 篇文章 1 订阅
216 篇文章 0 订阅

2019年11月谷歌的论文“Fast Transformer Decoding: One Write-Head is All You Need“。

Transformer神经序列模型中使用的多头注意层,是RNN的替代。虽然整个序列的并行性让这些层的训练通常快速而简单,但由于重复加载大型“Key”和“Value”张量的内存-带宽成本,增量推理(这种并行化不可能的情况下)通常运行缓慢。本文提出一种称为多查询注意的变型,其中Key和Value在所有不同的注意“头”之间共享,从而大大减少了这些张量的大小,继而减少了增量解码的内存-带宽要求。实验验证,产生的模型确实可以更快地解码,并且与基线相比只会有轻微的质量下降。

下面的代码描述了一个点积注意公式,其中权重被计算为具有不同Key的查询点积softmax:

添加图片注释,不超过 140 字(可选)

注:代码示例使用了在TensorFlow和numpy中定义的einsum表示法,用于任意维度张量之间的广义contractions计算;在这种表示法中,一个方程命名输入和输出的张量维度;该计算在数字上等效于广播每个输入得到所有维度的并集、按分量相乘、在所有维度上求和。

Transformer的序列-序列模型[Vaswani 2017]并行使用h个不同的注意层(头部),称之为“多头注意”。h个不同层的查询向量来自输入向量对应h个不同学习的线性投影Pq。类似地,Key和Value来自m个不同输入向量对应h个学习的线性投影Pl,Pv。h层的输出本身是通过不同的学习线性投影Po,然后求和得到。为了简单起见,输入和输出向量取相同的维数d。其计算可以表示如下:

添加图片注释,不超过 140 字(可选)

注:一个常数尺度因子在此忽略。

在实践中,将多个查询批处理在一起要高效得多。下面的代码添加了两种类型的批处理。首先,从序列中n个不同位置生成查询。这些查询都用相同Key和 Value交互。此外,同时处理一批b个不同的非交互序列。根据[Vaswani2017],在自回归模型中可以在包含值-∞的logits添加“掩码”来防止反向信息流。

添加图片注释,不超过 140 字(可选)

在某些设置中,由于数据依赖性,无法并行处理来自多个位置的查询。一个例子是自回归语言模型中的自注意层,如Transformer[Vaswani 2017]。在每个位置产生的查询涉及在该位置之前(包括该位置)产生的KV对。在训练过程中,真值目标序列是已知的,可以高效并行实现。然而,特定位置处的自注意层输出,当训练的模型生成时,会影响下一位置处生成的token,又影响到下一个位置这个自注意层的输入。这个阻止了并行计算。下面显示了用于增量计算该自注意层的代码:

添加图片注释,不超过 140 字(可选)

如下提出多查询注意,作为多头注意的变型。多头注意由多个并行的注意层(头)组成,对查询、Key、Value和输出进行不同的线性转换。多查询注意是相同的,只是不同的头共享一组Key和Value。(增量)多查询(自)注意的代码与上面列出的多头注意的代码相同,只是从tf.einsum方程中删除了字母“h”,其表示K、V、Pk或Pv的“头”维度。

添加图片注释,不超过 140 字(可选)
添加图片注释,不超过 140 字(可选)

  • 9
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值