大模型之模型训练篇
注意:文章内容参考了斯坦福CS324 - Large Language Models课程,以及[Datawhale的一起学相关课程中的内容]
老规矩先讲总结哈~
总结:在这一篇中主要是讨论了如何训练大模型(哈哈,并没有具体的案例哈)。在如何训练中分为了目标函数
和优化算法
两个部分。在目标函数部分是围绕encoder和decoder讲解了几个模型。优化算法这里则是讨论了大模型训练的时候如何进行优化。在优化算法这里收益匪浅,谢谢佬们!!!
1. 目标函数
顾名思义~,目标函数就是要找到我们的目标,TransFormer结构的语言模型通常由三类,每一类对应的典型模型都是不一样的,目标函数也是不一样的,接下来将分开讲解。
三类语言模型的目标函数:
- 只包含解码器(Decoder-only)的模型(例如,GPT-3):计算单向上下文嵌入(contextual embeddings),一次生成一个token
- 只包含编码器(Encoder-only)的模型(例如,BERT):计算双向上下文嵌入
- 编码器解码器(Encoder-decoder)模型(例如,T5):编码输入,解码输出
任何模型将token序列映射到上下文嵌入中(例如,LSTM、Transformers):
ϕ : V L → R d × L . \phi : V^L \to \mathbb{R}^{d \times L}. ϕ:VL→Rd×L.
[ the , mouse , ate , the , cheese ] ⇒ ϕ [ ( 1 0.1 ) , ( 0 1 ) , ( 1 1 ) , ( 1 − 0.1 ) , ( 0 − 1 ) ] . \left[\text{the}, \text{mouse}, \text{ate}, \text{the}, \text{cheese}\right] \stackrel{\phi}{\Rightarrow} \left[\binom{1}{0.1}, \binom{0}{1}, \binom{1}{1}, \binom{1}{-0.1}, \binom{0}{-1} \right]. [the,mouse,ate,the,cheese]⇒ϕ[(0.11),(10),(11),(−0.11),(−10)].
1.1 Decoder-only模型
首先定义一个条件分布
p
(
x
i
∣
x
1
:
i
−
1
)
p(x_i \mid x_{1:i-1} )
p(xi∣x1:i−1)
定义如下:
- 将 x 1 : i − 1 x_{1:i-1} x1:i−1映射到上下文的embedding(嵌入)中
- 应用嵌入矩阵 E ∈ R V × d E \in \R^{V \times d} E∈RV×d 来获得每个token的得分 E ϕ ( x 1 : i − 1 ) i − 1 E \phi(x_{1:i-1})_{i-1} Eϕ(x1:i−1)i−1 。
- 对其进行指数化和归一化,得到预测 x i x_i xi的 分布。
简单讲:
p
(
x
i
+
1
∣
x
1
:
i
)
=
s
o
f
t
m
a
x
(
E
ϕ
(
x
1
:
i
)
i
)
.
p(x_{i+1} \mid x_{1:i}) = softmax(E \phi(x_{1:i})_i).
p(xi+1∣x1:i)=softmax(Eϕ(x1:i)i).
1.1.1 最大似然
设
θ
\theta
θ 是大语言模型的所有参数。设
D
D
D 是由一组序列组成的训练数据。
然后,我们可以遵循最大似然原理,定义以下负对数似然目标函数:
O ( θ ) = ∑ x ∈ D − log p θ ( x ) = ∑ x ∈ D ∑ i = 1 L − log p θ ( x i ∣ x 1 : i − 1 ) . O(\theta) = \sum_{x \in D} - \log p_\theta(x) = \sum_{x \in D} \sum_{i=1}^L -\log p_\theta(x_i \mid x_{1:i-1}). O(θ)=x∈D∑−logpθ(x)=x∈D∑i=1∑L−logpθ(xi∣x1:i−1).
并且,有很多的方法可以有效地优化这一目标函数。
1.2 Encoder-only 模型
1.2.1 单向到双向
使用上述最大似然可以训练得到Decoder-only模型,它会产生(单向)上下文嵌入。但如果我们不需要生成,我们需要更强的双向上下文嵌入。
1.2.2 BERT
我们首先介绍BERT的目标函数,它包含以下两个部分:
- 掩码语言模型(Masked language modeling)
- 下一句预测(Next sentence prediction)
以自然语言推理(预测隐含、矛盾或中性)任务中的序列为例: