llama3 结构详解

1. Llama3 整体结构

  llama3 的整体结构还是延续transformer decoder 架构,其整体架构如下图左侧蓝色虚线框中所示。模型结构并不复杂,其主要组件为32个Transformer Block(32 为meta llama3 中的默认值)(见下图红色虚线框中所示)。

在这里插入图片描述

注 1 注_1 1: 下一节中会参照上图中 红色圆形序号 讲解各模块。
注 2 注_2 2: llama3的RoPE算法被拆成了3个方法来实现,上图中的模块2只包含了一个方法,另两个方法是在Attention模块(模块5)中进行的调用。

2. 模块详解

2.1 模块1: Embeddings

  llama3 的embedding 使用的是VocabParallelEmbedding这个类进行的向量转换,这个类是meta的fairscale包中的一个类,可以理解为对torch.nn.embedding做了并行化。

2.2 模块2: RoPE

  前文中已经提及llama3的RoPE算法被拆成了3个方法来实现,模块2只包含了一个方法,另两个方法是在Attention模块(模块5)中进行的调用。本小节具体按照RoPE的原始论文来讲解,主要阐述RoPE的算法原理。

2.2.1 从一个2维的例子说起 RoPE

  我们知道,寻找位置编码的基本思路是 输入位置编码经过特征提取的核心算法后的值,应能反应出两个位置之间的先后顺序(这点不是必要的)和相对位置信息。(《Transformer(二)–论文理解:transformer 结构详解》 2.1节 中有简单说明),RoPE的原始论文中给出了一个数学表达,如下式:
< f q ( x m , m ) , f k ( x n , n ) > = g ( x m , x n , m − n ) (2.1) <f_q(x_m,m),f_k(x_n,n)>=g(x_m,x_n,m-n) \tag{2.1} <fq(xm,m),fk(xn,n)>=g(xm,xn,mn)(2.1)
   f q ( x m , m ) f_q(x_m,m) fq(xm,m) f k ( x n , n ) f_k(x_n,n) fk(xn,n)分别为query和key。关于 g ( x m , x n , m − n ) g(x_m,x_n,m-n) g(xm,xn,mn),我理解为输入位置变量的计算函数,和我们使用特征抽取器相关,在transformer架构里,我们一般采用点积计算attention score(见公式2.7),所以, g ( x m , x n , m − n ) g(x_m,x_n,m-n) g(xm,xn,mn)的计算实质上应该还是计算点积(公式左边就是点积,我这里只是再啰嗦的说下为什么时点积形式)。这个函数的参数有三个, x m , x n x_m,x_n xm,xn是词向量,还有一个是 m − n m-n mn,这里之所以是 m − n m-n mn而不是 m m m n n n,是因为我们的特征抽取函数(点积注意力)已经做为已知条件固定了,所以我们要在数据进行特征抽取函数前进行变换。
  我们的目的就是找到一个这样的变换函数 f { q , k } f_{\{q,k\}} f{q,k}能表达 f q ( x m , m ) f_q(x_m,m) fq(xm,m) f k ( x n , n ) f_k(x_n,n) fk(xn,n),使 f q f_q fq f k f_k fk做点积操作后能保留 m − n m-n mn的信息。当然我们找到了,见公式2.2

  RoPE的论文中是先从2D情况下举例说明我们找到的 f ( x ) f(x) f(x)的,如下,当 d = 2 d=2 d=2时:

f q ( x m , m ) = ( W q x m ) e i m θ f k ( x n , n ) = ( W k x n ) e i n θ g ( x m , x n , m − n ) = R e [ ( W q x m ) ( W k x n ) ∗ e i ( m − n ) θ ] (2.2) f_q(x_m,m) = (\pmb{W}_{q}x_m)e^{im\theta} \\ f_k(x_n,n) = (\pmb{W}_{k}x_n)e^{in\theta} \\ g(x_m,x_n,m-n) = Re[(\pmb{W_q}x_m)(\pmb{W}_kx_n)^{*}e^{i(m-n)\theta}] \tag{2.2} fq(xm,m)=(Wqxm)eimθfk(xn,n)=(Wkxn)einθg(xm,xn,mn)=Re[(Wqxm)(Wkxn)ei(mn)θ](2.2)

  其中 R e [ ⋅ ] Re[ \cdot ] Re[]是复数的实部, ( W k x n ) ∗ (\pmb{W}_{k}x_n)^{*} (Wkxn)表示 ( W n ) (\pmb{W}_n) (Wn)的共轭复数。
θ ∈ R \theta \in \mathbb{R} θR 是一个预设的非零常数。我们可以进一步将 f { q , k } f_{\{q,k\}} f{q,k}写成乘法矩阵:
f { q , k } ( x m , m ) = ( c o s   m θ − s i n   m θ s i n   m θ c o s   m θ ) ( W { q , k } ( 11 ) W { q , k } ( 12 ) W { q , k } ( 21 ) W { q , k } ( 22 ) ) ( x m ( 1 ) x m ( 2 ) ) (2.3) f_{\{q,k\}}(x_m,m)= \left( \begin{matrix} cos\ m\theta & -sin\ m\theta \\ sin\ m\theta & cos\ m\theta \\ \end{matrix} \right) \left( \begin{matrix} W^{(11)}_{\{q,k\}} & W^{(12)}_{\{q,k\}} \\ W^{(21)}_{\{q,k\}} & W^{(22)}_{\{q,k\}} \\ \end{matrix} \right) \left( \begin{matrix} x^{(1)}_{m} \\ x^{(2)}_{m} \end{matrix} \right) \tag{2.3} f{q,k}(xm,m)=(cos mθsin mθsin mθcos mθ)(W{q,k}(11)W{q,k}(21)W{q,k}(12)W{q,k}(22))(xm(1)xm(2))(2.3)

其中, ( x m ( 1 ) , x m ( 2 ) ) (x^{(1)}_{m},x^{(2)}_{m}) (xm(1),xm(2)) x m x_m xm在二维坐标系中的表示。同样的, g g g也可以看作一个矩阵,因此可以在2维情况下求解公式(2.1)。

2.2.2 RoPE的一般形式

  为了将我们在2D中的结果推广到任意的 x i ∈ R d x_i \in \mathbb{R}^d xiRd,我们将d维空间划分为d/2个子空间,并根据内积的线性性质将它们组合起来,将 f { q , k } ( x m , n ) f_{\{q,k\}}(x_m,n) f{q,k}(xm,n)转化为:
f { q , k } ( x m , m ) = R Θ , m d W { q , k } x m (2.4) f_{\{q,k\}}(x_m,m)=\pmb{R}^{d}_{\Theta, m}\pmb{W}_{\{q,k\}}x_m \tag{2.4} f{q,k}(xm,m)=RΘ,mdW{q,k}xm(2.4)

  其中, W { q , m } \pmb{W}_{\{q,m\}} W{q,m} 表示与query和key 所对应的转换矩阵 , x m x_m xm 为输入向量, R Θ , m d \pmb{R}^d_{\Theta,m} RΘ,md为旋转矩阵,具体如下:
R Θ , m d = ( c o s   m θ 1 − s i n   m θ 1 0 0 ⋯ 0 0 s i n   m θ 1 c o s   m θ 1 0 0 ⋯ 0 0 0 0 c o s   m θ 2 − s i n   m θ 2 ⋯ 0 0 0 0 s i n   m θ 2 c o s   m θ 2 ⋯ 0 0 ⋮ ⋮ ⋮ ⋮ ⋱ ⋮ ⋮ 0 0 0 0 ⋯ c o s   m θ d / 2 − s i n   m θ d / 2 0 0 0 0 ⋯ s i n   m θ d / 2 c o s   m θ d / 2 ) (2.5) \pmb{R}^{d}_{\Theta,m}= \left( \begin{matrix} cos\ m\theta_1 & -sin\ m\theta_1 &0 &0 & \cdots &0 &0 \\ sin\ m\theta_1 & cos\ m\theta_1 &0 &0 & \cdots &0 &0 \\ 0 & 0 & cos\ m\theta_2 & -sin\ m\theta_2 & \cdots &0 &0 \\ 0 & 0 & sin\ m\theta_2 & cos\ m\theta_2 & \cdots &0 &0 \\ \vdots & \vdots &\vdots & \vdots & \ddots & \vdots & \vdots \\ 0 & 0 &0 &0 & \cdots & cos\ m\theta_{d/2} & -sin\ m\theta_{d/2} \\ 0 & 0 &0 &0 & \cdots & sin\ m\theta_{d/2} & cos\ m\theta_{d/2} \\ \end{matrix} \right) \tag{2.5} RΘ,md= cos mθ1sin mθ10000sin mθ1cos mθ1000000cos mθ2sin mθ20000sin mθ2cos mθ2000000cos mθd/2sin mθd/20000sin mθd/2cos mθd/2 (2.5)

Θ = { θ i = 1000 0 − 2 ( i − 1 ) / d , i ∈ [ 1 , 2 , . . . , d / 2 ] } (2.6) \Theta=\{ \theta_i = 10000^{-2(i-1)/d}, i \in [1,2,...,d/2] \} \tag{2.6} Θ={θi=100002(i1)/d,i[1,2,...,d/2]}(2.6)

2.2.3 RoPE的理解

  这里我们把我们求出的 f { q , k } ( x m , m ) = R Θ , m d W { q , k } x m f_{\{q,k\}}(x_m,m)=\pmb{R}^{d}_{\Theta, m}\pmb{W}_{\{q,k\}}x_m f{q,k}(xm,m)=RΘ,mdW{q,k}xm代入attention score的计算公式
a m , n = exp ⁡ ( q m T k n d ) ∑ j = 1 N exp ⁡ ( q m T k j d ) (2.7) a_{m,n}=\frac{\exp{(\frac{q^{T}_mk_n}{\sqrt{d}})}}{\sum^N_{j=1}{\exp{(\frac{q^{T}_mk_j}{\sqrt{d}})}}} \tag{2.7} am,n=j=1Nexp(d qmTkj)exp(d qmTkn)(2.7)

这里我们只需要看 q m T k m q^T_{m}k_m qmTkm即可,公式的其余部分不会改变结果形式。把公式2.4代入2.7

q m T k n = ( R Θ , m d W q x m ) T ( R Θ , n d W k x n ) = x T W q R Θ , n − m d W k x n (2.8) q^{T}_{m}k_n=(\pmb{R}^d_{\Theta,m}\pmb{W}_qx_m)^T(\pmb{R}^d_{\Theta,n}\pmb{W}_kx_n)=x^T\pmb{W}_{q}R^d_{\Theta,n-m}\pmb{W}_kx_n \tag{2.8} qmTkn=(RΘ,mdWqxm)T(RΘ,ndWkxn)=xTWqRΘ,nmdWkxn(2.8)

其中, R Θ , n − m d = ( R Θ , m d ) T R Θ , n d \pmb{R}^d_{\Theta,n-m} = (\pmb{R}^d_{\Theta,m})^T\pmb{R}^d_{\Theta,n} RΘ,nmd=(RΘ,md)TRΘ,nd,注意 R Θ d \pmb{R}^d_{\Theta} RΘd是一个正交矩阵,这保证了位置信息在处理过程中的稳定性。此外,由于 R Θ d \pmb{R}^d_{\Theta} RΘd的稀疏性,式(2.8)的计算效率不高,作者在理论上提供了另一种实现。

2.3 模块3: Transformer Block

  Transformer Block 模块是llama3的核心模块,或者说,llama3为Transformer Block模块堆叠而成。Transformer Block有模块4、5、6、7组成,具体内容见对应模块。

2.4 模块4: RMSNorm

  RSMNorm 是在 layer normalization 基础上优化而来,所以先简单回顾下layer normalization。(详细介绍见《Transformer(二)–论文理解:transformer 结构详解》 2.4节)
  layer normalization 是根据下面的公式对 x x x的分布进行调整。
x = a ∗ x − x ‾ s t d + e p s + b (2.9) x = a * \frac{x - \overline{x}}{std + eps} + b \tag{2.9} x=astd+epsxx+b(2.9)
其中, x ‾ \overline{x} x是均值, s t d std std是标准差, e p s eps eps为一个很小的数,防止分母为零。 a a a b b b为参数, b b b可以为零。
  我们现在来看看RMSNorm做了什么优化呢,其实他对上面的试子 x = a ∗ x − x ‾ s t d + e p s + b x = a * \frac{x - \overline{x}}{std + eps} + b x=astd+epsxx+b进行了简化。RMSNorm的计算公式如下:
a ‾ i = a i R M S ( a ) g i , w h e r e R M S ( a ) = 1 n Σ i = 1 n a i 2 (2.10) \overline{a}_i=\frac{a_i}{RMS(a)}g_{i}, \quad where \quad RMS(a) = \sqrt{\frac{1}{n}\Sigma^n_{i=1}{a^{2}_{i}}} \tag{2.10} ai=RMS(a)aigi,whereRMS(a)=n1Σi=1nai2 (2.10)

  从上式可以看出,RMSNorm移除了LayerNorm中的均值项(原式中的 x ‾ \overline{x} x项), s t d std std的计算中,也没有做减去均值的操作( s t d = 1 n Σ i = 1 n ( a i − a ‾ ) std=\sqrt{\frac{1}{n}\Sigma^n_{i=1}({a_i - \overline{a})}} std=n1Σi=1n(aia) )。这种简化在计算效率上有一定提高,且原始论文也说了,在效果上没有明显影响。

下面附上meta llama3中RMSNorm的源码,方便大家理解。

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

2.5 模块5: Attention

  llama3的attention模块主要做了4部分工作,分别是RoPE计算、分注意力分组机制实现、点积注意力计算 及 kv缓存策略实现。其中RoPE的计算在模块2中已经讲解,这里不在赘述。下文对GQA,点积注意力计算及KV缓存进行简单的讲解。

2.5.1 分组注意力机制(GQA)

  llama3中的attention模块与《Attention is all you need》中使用的attention技术有些许优化。同样是使用Scaled Dot-Product Attention来计算attention score,但分组优化这块没有延续使用MHA(Multi-head Attention)技术,而是使用了GQA(Grouped-Query Attention)分组技术。具体的Scaled Dot-Product Attention 与MHA我之前在《Transformer(二)–论文理解:transformer 结构详解》一文的2.2节中,已经写的非常详细了,所以这里不再展开,只讲解下GQA。

  我们知道,在MHA中,由于每个head都有独立的键和值,内存和计算成本较高,特别是在处理长序列或大批量数据时。然后就有大牛Noam Shazeer提出了MQA(Multi Query Attention)方法,将原来的h个KV对缩减为1个,所有query只使用一个共享的KV对,这种改造虽然大大减少了显存消耗,但其特征捕捉能力也受到影响。因此又提出了GQA(Grouped-Query Attention ), 将query 进行分组,每组共享一个KV对。下面是GQA原始论文中给出的对比图。
在这里插入图片描述
  为了清楚,这理举一个具体的例子:假设有一个token 序列 [ T 1 , T 2 , T 3 , T 4 , T 5 , T 6 , T 7 , T 8 ] [T_1,T_2,T_3,T_4,T_5,T_6,T_7,T_8] [T1,T2,T3,T4,T5,T6,T7,T8], 我们把这个token序列分成两个组来计算GQA。

  • step1: 分组
    G r o u p 1 = [ T 1 , T 2 , T 3 , T 4 ] G r o u p 2 = [ T 5 , T 6 , T 7 , T 8 ] Group_1=[T1,T2,T3,T4] \\ Group_2=[T5,T6,T7,T8] Group1=[T1,T2,T3,T4]Group2=[T5,T6,T7,T8]

  • step2:计算分组后的注意力
    每个组内部计算注意力分数。为简单起见,我们假设我们有以下简化的注意力机制:
    A t t e n t i o n   S c o r e ( Q i , K i ) = Q i ⋅ K i d k Attention\ Score(Q_i,K_i) = \frac{Q_i \cdot K_i}{\sqrt{d_k}} Attention Score(Qi,Ki)=dk QiKi
    其中 Q Q Q 是query, K K K 是 key, d k d_k dk​ 是键的维度。

    • 对于 G r o u p 1 Group_1 Group1
      • 计算标记 [ T 1 、 T 2 、 T 3 、 T 4 ] [T1、T2、T3、T4] [T1T2T3T4] 的注意力得分。
      • 这会生成一个 4×4 注意力矩阵。
    • 对于 G r o u p 2 Group_2 Group2:
      • 计算标记 [ T 5 、 T 6 、 T 7 、 T 8 ] [T5、T6、T7、T8] [T5T6T7T8] 的注意力得分。
      • 这会生成另一个 4×4 注意力矩阵。
  • step3: 共享组内注意力分数
    在每个组中,注意力分数是共享的。例如,第 1 组的注意力矩阵可能如下所示:
    A t t e n t i o n   S c o r e   G r o u p 1 = [ a 11 a 12 a 13 a 14 a 21 a 22 a 23 a 24 a 31 a 32 a 33 a 34 a 41 a 42 a 43 a 44 ] Attention\ Score_{\ Group_1} = \left[ \begin{matrix} a_{11} & a_{12} & a_{13} &a_{14} \\ a_{21} & a_{22} & a_{23} &a_{24} \\ a_{31} & a_{32} & a_{33} &a_{34} \\ a_{41} & a_{42} & a_{43} &a_{44} \\ \end{matrix} \right] Attention Score Group1= a11a21a31a41a12a22a32a42a13a23a33a43a14a24a34a44
    对于第二个分组:
    A t t e n t i o n   S c o r e   G r o u p 2 = [ a 51 a 52 a 53 a 54 a 61 a 62 a 63 a 64 a 71 a 72 a 73 a 74 a 81 a 82 a 83 a 84 ] Attention\ Score_{\ Group_2} = \left[ \begin{matrix} a_{51} & a_{52} & a_{53} &a_{54} \\ a_{61} & a_{62} & a_{63} &a_{64} \\ a_{71} & a_{72} & a_{73} &a_{74} \\ a_{81} & a_{82} & a_{83} &a_{84} \\ \end{matrix} \right] Attention Score Group2= a51a61a71a81a52a62a72a82a53a63a73a83a54a64a74a84

  • step4:注意力计算
    组中的每个标记根据计算出的分数关注其组中的其他标记(具体计算方法见2.5.2节)。例如, T 1 T_1 T1 将使用第 1 组注意力矩阵第一行的分数关注 T 2 T_2 T2 T 3 T_3 T3 T 4 T_4 T4

  • step5:合并结果
    在计算每个组内的注意力后,我们将结果合并以形成最终的输出序列。每个标记的输出是它关注的标记值的加权和。

  • 优点总结 :对查询进行分组有以下两个优点

    • 降低复杂度:我们不再计算 8×8 矩阵的注意力,而是计算两个 4×4 矩阵,从而显著减少了计算量。
    • 可扩展性:此方法更适合长序列,因为注意力计算随组大小而非整个序列长度二次增长。

2.5.2 注意力计算(Scaled Dot-Product Attention)

  llama3 计算attention score时,使用了与《attention is all you need》一文中相同的计算方法,即点积注意力方法(Scaled Dot-Product Attention),由于Scaled Dot-Product Attention在《Transformer(二)–论文理解:transformer 结构详解》 一文中的2.2.1章节有详细的讲解,这里就不再展开。

2.5.3 KV缓存

   llama3在计算 attention 时采用了kv cache策略。此策略的思想是缓存每个时间步的key和value的值,在推理阶段,由于模型是自回归模式生成文本,所以当我们对过往时间步有缓存结果时,会减少计算量,提高解码效率。

下面是llama3中Attention类的源码,大家可以参考理解

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
		.
		.
		.
    

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(
            keys, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(
            values, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(
            1, 2
        )  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        # 以下是Scaled Dot-Product Attention的计算
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)

2.6 模块6: ADD

   此模块做了个类似残差的操作,但与残差不同的是,不是用输入减去输出,而是用输入加上输出。具体操作就是把模块4的输入与模块5的输出做加法运算。

2.7 模块7: FFN

  由3个Linear组成的FeedForward网络,这里的激活函数使用的siLU。siLU的数学公式如下:
s i l u ( x ) = x ∗ σ ( x ) ,    w h e r e   σ ( x )   i s   t h e   l o g i s t i c   s i g m o i d . silu(x)=x*\sigma(x), \ \ where\ \sigma(x)\ is\ the\ logistic\ sigmoid. silu(x)=xσ(x),  where σ(x) is the logistic sigmoid.

函数的激活曲线如下图:
在这里插入图片描述
在里注意下,siLU 还有一个名字叫“swish function”,这个在 pytorch 的官方文档中有说明。

下面给出主要源码。


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()

        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        .
        .
        .
  

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

2.8 模块8: Linear

  此模块的目的是把模型中 decoder的输出从 d m o d e l d_{model} dmodel维度映射到词表大小的维度。下面是meta llama中的linear层的初始化。

 self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值