论文地址:https://arxiv.org/pdf/2405.04434
相关博客
【自然语言处理】【大模型】语言模型物理学 第3.3部分:知识容量Scaling Laws
【自然语言处理】Transformer中的一种线性特征
【自然语言处理】【大模型】DeepSeek-V2论文解析
【自然语言处理】【大模型】BitNet:用1-bit Transformer训练LLM
【自然语言处理】BitNet b1.58:1bit LLM时代
【自然语言处理】【长文本处理】RMT:能处理长度超过一百万token的Transformer
【自然语言处理】【大模型】MPT模型结构源码解析(单机版)
【自然语言处理】【大模型】ChatGLM-6B模型结构代码解析(单机版)
【自然语言处理】【大模型】BLOOM模型结构源码解析(单机版)
一、简介
- DeepSeek-V2是一个总参数为236B的MoE模型,每个token仅激活21B的参数,并支持128K的上下文长度。
- 提出了Multi-head Latent Attention(MLA),通过压缩kv cache至隐向量,从而保证高效推理。
- 相比于DeepSeek 67B,DeepSeek-V2实现了更好的表现,节约了42.5%的训练成本,降低了93.3%的kv cache,提升最大吞吐5.76倍。
- 预训练语料包含了8.1T tokens并进一步进行SFT和RL。
二、模型结构
1. MLA(Multi-Head Latent Attention)
传统Transformer采用MHA(Multi-Head Attention),但是kv cache会成为推理瓶颈。MQA(Multi-Query Attention)和GQA(Grouped-Query Attention)可以一定程度减少kv cache,但效果上不如MHA。DeepSeek-V2设计了一种称为MLA(Multi-Head Latent Attention)的注意力机制。MLA通过低秩key-value联合压缩,实现了比MHA更好的效果并且需要的kv cache要小很多。
1.1 标准MHA
令 d d d为embedding维度, n h n_h nh是注意力头的数量, d h d_h dh是每个头的维度, h t ∈ R d \textbf{h}_t\in\mathbb{R}^d ht∈Rd是注意力层中第 t t t个token的输入。标准MHA通过三个矩阵 W Q , W K , W V ∈ R d h n h × d W^Q,W^K,W^V\in\mathbb{R}^{d_h n_h\times d} WQ,WK,WV∈Rdhnh×d来产生 q t , k t , v t ∈ R d h n h \textbf{q}_t,\textbf{k}_t,\textbf{v}_t\in\mathbb{R}^{d_h n_h} qt,kt,vt∈Rdhnh。
q t = W Q h t k t = W K h t v t = W V h t \begin{align} \textbf{q}_t&=W^Q\textbf{h}_t \tag{1}\\ \textbf{k}_t&=W^K\textbf{h}_t \tag{2}\\ \textbf{v}_t&=W^V\textbf{h}_t \tag{3}\\ \end{align} \\ qtktvt=WQht=WKht=WVht(1)(2)(3)
在MHA中 q t , k t , v t \textbf{q}_t,\textbf{k}_t,\textbf{v}_t qt,kt,vt会被划分为 n h n_h nh个头:
[ q t , 1 ; q t , 2 ; … , q t , n h ] = q t [ k t , 1 ; k t , 2 ; … , k t , n h ] = k t [ v t , 1 ; v t , 2 ; … , v t , n h ] = v t o t , i = ∑ j = 1 t Softmax ( q t , i ⊤ k j , i d h ) v j , i u t = W O [ o t , 1 ; o t , 2 ; … , o t , n h ] \begin{align} &[\textbf{q}_{t,1};\textbf{q}_{t,2};\dots,\textbf{q}_{t,n_h}]=\textbf{q}_t \tag{4}\\ &[\textbf{k}_{t,1};\textbf{k}_{t,2};\dots,\textbf{k}_{t,n_h}]=\textbf{k}_t \tag{5}\\ &[\textbf{v}_{t,1};\textbf{v}_{t,2};\dots,\textbf{v}_{t,n_h}]=\textbf{v}_t \tag{6}\\ &\textbf{o}_{t,i}=\sum_{j=1}^t\text{Softmax}(\frac{\textbf{q}_{t,i}^\top\textbf{k}_{j,i}}{\sqrt{d_h}})\textbf{v}_{j,i} \tag{7}\\ &\textbf{u}_t=W^O[\textbf{o}_{t,1};\textbf{o}_{t,2};\dots,\textbf{o}_{t,n_h}] \tag{8}\\ \end{align} \\ [qt,1;qt,2;…,qt,nh]=qt[kt,1;kt,2;…,kt,nh]=kt[vt,1;vt,2;…,vt,nh]=vtot,i=j=1∑tSoftmax(dhqt,i⊤kj,i)vj,iut=WO[ot,1;ot,2;…,ot,nh](4)(5)(6)(7)(8)
其中 q t , i , k t , i , v t , i ∈ R d h \textbf{q}_{t,i},\textbf{k}_{t,i},\textbf{v}_{t,i}\in\mathbb{R}^{d_h} qt,i,kt,i,vt,i∈Rdh是第 i i i个注意力头的query、key和value, W O ∈ R d × d h n h W^O\in\mathbb{R}^{d\times d_h n_h} WO∈Rd×dhnh是输出投影矩阵。在推理时,所有的key和value都会被缓存来加速推理。对于每个token,MHA需要缓存 2 n h d h l 2n_h d_h l 2nhdhl个元素。
1.2 低秩Key-Value联合压缩
MLA通过低秩联合压缩key和value来减少kv cache:
c t K V = W D K V h t k t C = W U K c t K V v t C = W U V c t K V \begin{align} \textbf{c}_t^{KV}&=W^{DKV}\textbf{h}_t \tag{9}\\ \textbf{k}_t^C&=W^{UK}\textbf{c}_t^{KV} \tag{10}\\ \textbf{v}_t^C&=W^{UV}\textbf{c}_t^{KV} \tag{11}\\ \end{align} \\ ctKVkt