Arxiv 2307 | Retentive Network: A Successor to Transformer for Large Language Models

Retentive Network: A Successor to Transformer for Large Language Models

image.png

image.png

本文从序列建模的角度,构建了一种类似Transformer且更加高效的结构。在语言任务上展现出了良好的效率和性能。

  • 利用类似于Transformer的并行组件实现了对于GPU并行能力的利用。
  • 利用循环机制确保了 O ( 1 ) O(1) O(1)级别的存储和计算复杂度。
  • 利用分块循环策略从而执行有效的长序列建模。

实际中,并行编码每个局部的块来加速计算,同时循环编码全局块来节省显存。

序列建模

对于输入的长度为 N N N的文本嵌入序列,由于其本身信息的前后依赖关系和因果关系的需求,所以本文是从循环模型的角度开始构建模型的。

基础的迭代形式:

对于第n次迭代的输入 X n X_n Xn,有

Q n = X n ⋅ W Q , K n = X n ⋅ W K , V n = X n ⋅ W V ∈ R 1 × d Q_n = X_n \cdot W_Q, K_n = X_n \cdot W_K, V_n = X_n \cdot W_V \in \mathbb{R}^{1 \times d} Qn=XnWQ,Kn=XnWK,Vn=XnWVR1×d

将序列建模认为成通过状态 S n S_n Sn,将 V ( n ) V(n) V(n)映射为 O ( n ) O(n) O(n)**的过程。**于是可以得到下式:

S n = A s n − 1 + K n ⊤ V n = A n − 1 K 1 ⊤ V 1 + A n − 2 K 2 ⊤ V 2 + ⋯ + K n ⊤ V n = ∑ m = 1 n A n − m K m ⊤ V m S_n = As_{n-1} + K^{\top}_n V_n = A^{n-1} K^{\top}_1 V_1 + A^{n-2} K^{\top}_2 V_2 + \dots + K^{\top}_n V_n = \sum^{n}_{m=1} A^{n-m} K^{\top}_m V_m Sn=Asn1+KnVn=An1K1V1+An2K2V2++KnVn=m=1nAnmKmVm

这里的 A ∈ R d × d A \in \mathbb{R}^{d \times d} ARd×d描述了各个位置之间的相对关系。

O n = Q n S n = ∑ m = 1 n Q n A n − m K m ⊤ V m , Q n ∈ R 1 × d O_n = Q_n S_n = \sum^{n}_{m=1}Q_n A^{n-m} K^{\top}_m V_m, Q_n \in \mathbb{R}^{1 \times d} On=QnSn=m=1nQnAnmKmVm,QnR1×d

Parallel Retention

通过设置一个特殊的矩阵 A A A,将其对角化处理获得 A = Λ ( γ e i θ ) Λ − 1 A = \Lambda (\gamma e^{i \theta}) \Lambda^{-1} A=Λ(γeiθ)Λ1,这里的两个矩阵 Λ \Lambda Λ由于在公式中紧邻 Q n , K n Q_n, K_n Qn,Kn,所以可以将其合并到二者各自的权重矩阵 W Q , W K W_Q, W_K WQ,WK中一同随着网络去学习,从而上式可以改写:

O n = Q n S n = ∑ m = 1 n Q n ( γ e i θ ) n − m K m ⊤ V m = ∑ m = 1 n [ Q n ( γ e i θ ) n ] [ K m ( γ e i θ ) − m ] ⊤ V m = ∑ m = 1 n γ n − m ( Q n e i n θ ) ( K m e i m θ ) † V m O_n = Q_n S_n = \sum^{n}_{m=1} Q_n (\gamma e^{i \theta})^{n-m} K^{\top}_m V_m = \sum^{n}_{m=1} [Q_n (\gamma e^{i \theta})^{n}] [K_m (\gamma e^{i \theta})^{-m}]^{\top} V_m = \sum^{n}_{m=1} \gamma^{n-m} (Q_n e^{i n \theta}) (K_m e^{i m \theta})^{\dagger} V_m On=QnSn=m=1nQn(γeiθ)nmKmVm=m=1n[Qn(γeiθ)n][Km(γeiθ)m]Vm=m=1nγnm(Qneinθ)(Kmeimθ)Vm

这里将指数与转置融合和获得共轭转置。这里的复数系数实际上可以看做是一种位置嵌入,由于这里的计算反映出了与位置n和m的关联,所以可以认为是一种相对位置关系的表示。

由于这里 Q , K Q, K Q,K索引上的独立性,所以很容易改为并行的基于矩阵运算的结构。将复数矩阵系数极其共轭形式分别合并到 Q , K Q, K Q,K计算过程中,从而可以得到:

Q = ( X W Q ) ⊙ Θ , K = ( X W K ) ⊙ Θ ˉ , V = X W V , Θ n = e i n θ D n m = γ n − m  if  n ≥ m  else  0 Q=(XW_Q) \odot \Theta, K=(XW_K) \odot \bar{\Theta}, V=XW_V, \Theta_{n} = e^{i n \theta} \\ D_{nm}=\gamma^{n-m} \text{ if } n \ge m \text{ else } 0 Q=(XWQ)Θ,K=(XWK)Θˉ,V=XWV,Θn=einθDnm=γnm if nm else 0

从而得到整体模块的计算过程:

R e t e n t i o n ( X ) = ( Q K ⊤ ⊙ D ) V , D ∈ R N × N Retention(X) = (QK^{\top} \odot D) V, D \in \mathbb{R}^{N \times N} Retention(X)=(QKD)V,DRN×N

def ParallelRetention(
    q, # bsz ∗ num_head ∗ len ∗ qk_dim
    k, # bsz ∗ num_head ∗ len ∗ qk_dim
    v, # bsz ∗ num_head ∗ len ∗ v_dim
    decay_mask # num_head ∗ len ∗ len
 ):
     retention = q @ k.transpose(1,2)
     retention = retention ∗ decay_mask
     output = retention @ v
     output = group_norm(output)
     return output

这一形式实际上与Transformer的带mask的计算形式非常类似。

这里由于有 Q K ⊤ QK^\top QK,使用了三种归一化方式来提升数值精度,这些归一化策略实际上都是在GN输入上乘以了一个常数,而由于GN本身的尺度不变性,所以必不会影响GN的输出和反向的梯度。

  • 使用特征维度归一化 Q K ⊤ / d Q K^\top / \sqrt{d} QK/d
  • 设置 D = { D n m ∑ i = 1 n D n i } D = \{\frac{D_{nm}}{\sqrt{\sum^n_{i=1}D_{ni}}}\} D={i=1nDni Dnm}
  • 假定 R = Q K ⊤ ⊙ D R = Q K^{\top} \odot D R=QKD,设置 R = { R n m max ⁡ ( ∣ ∑ i = 1 n R n i ∣ , 1 ) } R = \{ \frac{R_{nm}}{\max(|\sum^{n}_{i=1} R_{ni}|, 1)} \} R={max(i=1nRni,1)Rnm}

Recurrent Retention

但是,如果从序列形式的角度来看,前面的最一开始的建模过程也可以改写成另外一种类似于RNN的形式。
先将状态参数写成迭代形式:

S n = γ S n − 1 + K n ⊤ V n ∈ R d × d S_n = \gamma S_{n-1} + K^{\top}_n V_n \in \mathbb{R}^{d \times d} Sn=γSn1+KnVnRd×d

最终可以得到整体迭代计算过程:

R e t e n t i o n ( X n ) = Q n S n , n ∈ { 1 , … , N } Retention(X_n) = Q_n S_n, n \in \{1,\dots,N\} Retention(Xn)=QnSn,n{1,,N}

def RecurrentRetention(
    q, k, v, # bsz ∗ num_head ∗ len ∗ qkv_dim
    past_kv, # bsz ∗ num_head ∗ qk_dim ∗ v_dim
    decay # num_head ∗ 1 ∗ 1
):
    current_kv = decay ∗ past_kv + k.unsqueeze(1) ∗ v.unsqueeze(2)
    output = torch.sum(q.unsqueeze(1) ∗ current_kv, dim=2)
    output = group_norm(output)
    return output, current_kv

实际上这里的形式与线性Attention先计算KV的思路颇有相通之处。

Chunkwise Recurrent Retention

作者也提出了一种将上述两种形式进行混合的形式,通过将序列划分为连续的块,块内部执行并行形式的处理,块之间执行循环处理,实际的,对于第 i i i个块,处理形式如下:

R e t e n t i o n ( X [ i ] ) = ( Q [ i ] K [ i ] ⊤ ⊙ D ) V [ i ] ⏟ 块内并行 + ( Q [ i ] S i ) ⊙ ξ ⏟ 块间循环 , ξ i j = γ i + 1 Retention(X_{[i]})=\underbrace{(Q_{[i]} K^{\top}_{[i]} \odot D)V_{[i]}}_{块内并行} + \underbrace{(Q_{[i]} S_i) \odot \xi}_{块间循环}, \xi_{ij} = \gamma^{i+1} Retention(X[i])=块内并行 (Q[i]K[i]D)V[i]+块间循环 (Q[i]Si)ξ,ξij=γi+1

def ChunkwiseRetention(
    q, k, v, # bsz ∗ num_head ∗ chunk_size ∗ qkv_dim
    past_kv, # bsz ∗ num_head ∗ qk_dim ∗ v_dim
    decay_mask, # num_head ∗ chunk_size ∗ chunk_size
    chunk_decay, # num_head ∗ 1 ∗ 1
    inner_decay, # num_head ∗ chunk_size
):
    retention = q @ k.transpose(1,2)
    retention = retention ∗ decay_mask
    inner_retention = retention @ v
    cross_retention = (q @ past_kv) ∗ inner_decay
    retention = inner_retention + cross_retention
    output = group_norm(retention)
    current_kv = chunk_decay ∗ past_kv + k.transpose(1,2) @ v
    return output, current_kv

Gated Multi-Scale Retention

引入“头”的概念,为不同的头使用不同的变换权重,同时为不同的头分配不同的 γ \gamma γ

image.png

image.png

由于不同的头引入了不同的参数 γ \gamma γ,所以实际的方差统计量会有所差异。所以这里使用GroupNorm独立归一化不同的头。

实际效果

仅在文本任务上进行了实验。

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

  • 4
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
回答: 本文探索了将普通的Vision Transformer (ViT)作为目标检测的骨干网络。通过对ViT架构进行微调,而无需重新设计分层骨干进行预训练,我们的普通骨干检测器可以取得竞争性的结果。研究发现,只需从单尺度特征图构建简单的特征金字塔(无需常见的FPN设计),并使用窗口注意(无需移动)辅助少量的跨窗口传播块即可。通过使用预先训练的纯ViT主干作为Masked Autoencoders (MAE),我们的检测器ViTDet可以与之前基于分层骨干的领先方法竞争,在COCO数据集上达到61.3 APbox的性能。我们希望这项研究能够引起对普通骨干检测器的关注。\[1\]\[2\]\[3\] #### 引用[.reference_title] - *1* [论文阅读-ViTDet:Exploring Plain Vision Transformer Backbones for Object Detection](https://blog.csdn.net/qq_37662375/article/details/126675811)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down28v1,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* *3* [ViTDet:Exploring Plain Vision Transformer Backbonesfor Object Detection(arXiv 2022)](https://blog.csdn.net/qq_54828577/article/details/127262932)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down28v1,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值