LLama3 模型学习笔记
Note:本文参考了 transformers GitHub 仓库 中对 LLama 的实现
整体结构
其中,Decode 部分可重复多次,单个 Decode 部分可表示为:
RMSNorm
μ = mean ( hidden_states 2 , axis = − 1 , keepdim = true ) hidden_states = hidden_states × 1 μ + ϵ output = self.weight × hidden_states \mu = \text{mean}(\text{hidden\_states}^2, \text{axis} = -1, \text{keepdim}=\text{true}) \\ \text{hidden\_states} = \text{hidden\_states} \times \frac{1}{\sqrt{\mu + \epsilon}} \\ \text{output} = \text{self.weight} \times \text{hidden\_states} μ=mean(hidden_states2,axis=−1,keepdim=true)hidden_states=hidden_states×μ+ϵ1output=self.weight×hidden_states
其中, μ \mu μ 表示对 hidden_states 求得的平均值, ϵ \epsilon ϵ 是一个很小的数,目的是防止 μ = 0 \mu = 0 μ=0 时分母为0
Multi-head Attention
位置编码采用的是 Rope 旋转位置编码,首先需要生成一个逆频率向量:
inv_freq
i
=
1
base
2
i
d
\text{inv\_freq}_i = \frac{1}{\text{base}^{\frac{2i}{d}}}
inv_freqi=based2i1
其中,d 代表总维度,base 通常取 10, 000。之后需要计算频率矩阵:
freqs
pos
,
i
=
pos
⋅
inv_freq
i
emb
pos
,
i
=
{
c
o
s
(
freqs
pos
,
i
/
2
)
,
i
是偶数
s
i
n
(
freqs
pos
,
(
i
−
1
)
/
2
)
,
i
是奇数
\text{freqs}_{\text{pos},i} = \text{pos} \cdot \text{inv\_freq}_i \\ \text{emb}_{\text{pos}, i} = \left\{ \begin {align*} cos(\text{freqs}_{\text{pos}, i/2}), &i \text{是偶数} \\ sin(\text{freqs}_{\text{pos}, (i-1)/2}), &i \text{是奇数} \end {align*} \right.
freqspos,i=pos⋅inv_freqiembpos,i={cos(freqspos,i/2),sin(freqspos,(i−1)/2),i是偶数i是奇数
最后,应用旋转变换:
x
′
=
x
⋅
c
o
s
(
e
m
b
)
+
x
⊥
⋅
s
i
n
(
e
m
b
)
x' = x \cdot cos(emb) + x_{\bot} \cdot sin(emb)
x′=x⋅cos(emb)+x⊥⋅sin(emb)
MLP 部分
LLama 论文中提到的是使用 SwiGLU 激活函数,但是 transformers 的实现中似乎并没有采用这种激活函数。
下游任务
激活函数,但是 transformers 的实现中似乎并没有采用这种激活函数。
下游任务
本质就是接一个线形层,利用不同的数据与损失函数针对不同下游任务进行训练。