(2024|ICLR reviewing,Transformer-VQ,自注意力线性计算时间,切片和滑动窗,缓存和迭代)

Transformer-VQ: Linear-Time Transformers via Vector Quantization

公和众和号:EDPJ(进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群)

目录

0. 摘要

2. 基础

2.1 符号表示 

2.2 向量量化

2.3 向量量化器和码本

2.4 向量量化表示学习

3. Transformer-VQ

3.1 二次时间公式

3.2 预热:线性时间编码器注意力

3.3 线性时间解码器注意力

3.4 学习算法

3.4.1 训练损失

3.4.2 训练更新

5. 实验 


0. 摘要

我们引入 Transformer-VQ,这是一个 decoder-only 的 Transformer,通过线性时间计算基于 softmax 的密集自注意力。Transformer-VQ 的高效注意力是通过矢量量化的 key 和一种新颖的缓存机制实现的。在大规模实验中,Transformer-VQ 在质量上表现出色,在 Enwik8(0.99 bpb)、PG-19(26.6 ppl)和 ImageNet64(3.16 bpb)上取得了强大的结果。

代码:https://github.com/transformer-vq/transformer_vq

2. 基础

2.1 符号表示 

实数用 R 表示,扩展实数 R ∪ {−∞,∞} 用 ¯R 表示。对于所有张量,使用零为基础的索引。在对矩阵 M 进行第一个轴的索引时,我们使用 M_i 表示列向量,M_(i,:) 表示行向量。函数 LN(·)、Softmax(·)、Concat(·) 分别表示 LayerNorm(Ba等人,2016)、softmax 和连接,每个都逐行应用。符号 ≜、∝、⊙、exp(·)、δ_a,b、SG(·) 分别表示定义相等、成比例、逐元素乘积、逐元素指数、Kronecker δ 函数和停梯度运算符。

我们假设熟悉 transformers(Vaswani等人,2017),并使用符号 D_m 表示模型宽度,H 表示每层的注意力头数,D_k 表示 query/key 矢量宽度,D_v 表示 value 矢量宽度,D_f 表示前馈扇出宽度。

2.2 向量量化

向量量化(VQ)是在这项工作中广泛使用的一种技术。在这里,我们简要回顾 VQ,阐述其在自注意力中的应用动机,并讨论由 van den Oord等人(2017)引入的用于表示学习的 VQ 方案。所有证明都在附录A中给出。

2.3 向量量化器和码本

定义2.1. 向量量化器是一个函数 VQ(·;C),其定义域为 R^D,值域为 R^D。对于输入 x,其输出 ˆx为 

其中 C ∈ R^(S×D) 被称为码本。C 的行索引 {0,...,S−1} 称为短码,行本身称为码字。

定理 2.2.(基于 Guo等人(2019))。设q ∈ RD是一个随机变量,满足 E_q[qq^⊤] ∝ I_D,而 k ∈ R^D 是与 q 独立的随机变量。设 φ: R^D → R^D是一个确定性函数。那么

推论2.3. 假设定理 2.2 的条件成立。在

的约束下,选择 φ(·) = VQ(·;C) 使得公式 3 中左侧的期望最小化。

推论2.4. 假设定理 2.2 的条件成立。对于 ˆk = VQ(k;C),我们有

备注 2.5. 由于找到全局最小化器

可能很昂贵,我们使用 van den Oord等人(2017)的流式 k 均值的小批量变体进行近似。

2.4 向量量化表示学习

定义 2.6.(基于 van den Oord等人(2017))。带有直通估计器的向量量化器是一个具有定义域域 R^D 和值域 R^D 的函数 STVQ(·;C)。对于输入 x,其输出 ˆx 为

备注 2.7. 对于任意 x ∈ R^D,STVQ(x;C) 的计算结果等同于 VQ(x;C)。然而,相对于其输入的量化器的雅可比,现在将在所有地方都是 1 的矩阵,而不是几乎处处是 0 的矩阵。直观地说,使用STVQ 时,梯度被从其量化对应物中 “移植” 到未量化向量上。

备注 2.8. 我们重载符号 STVQ(·;C) 以逐行操作矩阵 value 输入。 

3. Transformer-VQ

我们现在提出 Transformer-VQ,这是一个 decoder-only 的 Transformer,可以在线性时间内计算密集自注意力。所有理论结果的证明都在附录 A 中给出。 

3.1 二次时间公式

定义 3.1. 矢量量化自注意力是一个函数 VQAttn(· ; C,W_{Q,K,V,G,O}),其定义域为 R^(T×D_m),共域为 R^(T×D_m)。对于输入 X,其输出 Y 定义为通过下式得到:

其中,τ 是一个固定常数,ϕ_v、ϕ_g、ϕ_w 是逐元素或逐行的非线性,query / key 的 LayerNorm 使用单位增益和零偏置,STVQ(·;C) 表示对矢量量化的逐行应用,使用直通梯度估计器(straight-through gradient estimator)(van den Oord et al., 2017)。

备注 3.2. 我们的注意力机制应用于受 Hua 等人(2022)启发的门控激活单元(gated activation unit,GAU)设计。GAU 是一个单头门控注意机制,通常使用 D_k = 128,D_v = 2D_m,其中两个 GAU 替换一个单一的 Transformer 层。这产生了与 Transformer 层相似的参数计数和计算要求,假设后者使用 D_m ≫ 128,D_k = D_v = D_m / H,以及 D_f = 4D_m。

备注 3.3. 先前的研究还在注意力中对 query / key 应用了 LayerNorm 或类似的方法(Henry et al., 2020; Roy et al., 2021; Zhu et al., 2021; Wu et al., 2022; Hutchins et al., 2022),通常发现这有助于提高数值稳定性和收敛性。

3.2 预热:线性时间编码器注意力

定理 3.4. 假设对于所有 i,j,B_(i,j) = 0,并且 ϕ_w 是逐行的 softmax 非线性。那么在定义 3.1 中的注意权重可以被分解为:

在这里,δ_(·,·) 表示 Kronecker δ 函数,z_t 是时步 t 的 VQ 短码。 

3.3 线性时间解码器注意力

定理 3.6. 设 L 是 T 的除数。假设对于 j > i,B_i,j = −∞,对于 j < i − L,B_i,j = 0。设 ϕ_w 是逐元素非线性,其中 ϕ_w(−∞) = 0。对于张量 M,令 M^(...,n,...) 表示切片 M_(...,nL,(n+1)L,....)。对于特定的张量,如果某个轴没有被切片,每个省略号将被相应数量的 “:” 替换。那么在定义 3.1 中的积 WV 可以使用递归计算: 

其中,如果任何块切片索引 n 小于零,则任何张量切片 M^(...,n,...) 在切片的维度上被定义为宽度为 L 的零张量。

定理 3.7. 设 L 是 T 的除数。假设对于 j > i,B_i,j = −∞,对于 j < i − L,B_i,j = 0。设 ϕ_w 是逐行 softmax 非线性。令

那么在定义 3.1 中的积 WV 可以使用递归计算:

备注 3.8. 直观地说,定理 3.7 表明,VQ-Attention 可以通过以长度为 L 的块处理序列,对每个块应用两个步骤来计算。第一步是形成相应的 Δ 块,并用其将 value 向量和短码指示符加到 “缓存” 变量 U(n),L(n) 的适当行。第二步是在检索过程中直接利用码本 C 将 U(n),L(n) 合并。

备注 3.9. 定理 3.7 提供了一种从 queries, keys, values, gates 和码本中计算 VQ-Attention 的算法,每个查询块的时间复杂度为 O(L(S + 2L)(D_k + D_v)),因此每个序列的时间复杂度为 O(T(S + 2L)(D_k + D_v))。

备注 3.10. 在实验中,我们使用 ϕ_w 作为逐行 softmax,并使用 Dai 等人(2019)的相对位置偏差(relative positional biases)来表示 B 中非零偏差的带。我们依赖于定理 3.7 的数值稳定重构,其中计数 L(n − 2) 的对数被移入 AV 和 A1 中的指数 exp(Q^(n,:)C^T)。

3.4 学习算法

3.4.1 训练损失

设 θ 表示具有 N 个 VQ-Attention 层的 Transformer 的非码本数集,

表示层次的代码簿集。对于序列

的自回归建模,我们定义 Transformer-VQ 训练损失为:

其中 β > 0 是一个超参数,是 commit loss 系数的,且 

因此,训练损失是下一个标记的平均交叉熵损失,加上标记的平均 commit loss(van den Oord et al., 2017),在码本上求和。非码本参数 θ 从两个损失项都获得梯度。按照 van den Oord 等人(2017)的做法,码本通过平滑的量化器统计参数化。

3.4.2 训练更新

与更新上述完整序列损失的方式不同,我们通常每隔 K 个 query 块进行一次更新,其中 LK ≪ T,这类似于先前研究中使用的策略(Dai et al., 2019; Wu et al., 2022; Hutchins et al., 2022)。

通过 LK 个时间步的窗口反向传播,来进行每次更新,梯度是从上述每个标记的平均损失中计算得到的相应项。码本也在每隔 K 个 query 块进行一次更新。

当 K = 1 时,使用定理 3.7 是使用不可微分的长程 key-value 缓存的高效替代方法。当 K > 1 时,学习信号通过反向传播窗口内添加到压缩缓存中的任何 value 向量发送。

5. 实验 

BPB:bits-per-byte(越低越好) 

  • 18
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值