显存
显存占用分析
- Model States
- 模型参数
- 后向传递计算得到的 梯度
- 优化器状态
- Activation
- 前向计算过程中产生的 中间激活
数据类型
- float32(FP32):32 位浮点数,也称为单精度。
- float16(FP16):16 位浮点数,表示范围较小,也被称为半精度。
- bfloat16(BF16):扩大了指数位数,缩小了小数位数,因此表示的范围更大,精度更弱。
一般采用 16 位的表示,那么一个参数占用 2byte,即 2B。
FP16 的精度高,但是表示范围小,容易上溢;
BF16 的表示范围大,但精度低,因此更容易下溢,为了避免溢出问题,提出了混合精度方案。
训练过程
训练大模型时通常会采用 AdamW 优化器,并用 混合精度 训练来加速训练,基于这个前提分析显存占用。
在一次训练迭代中,每个可训练模型参数都会对应 1 个梯度,并对应 2 个优化器状态(Adam 优化器梯度的一阶动量和二阶动量)。
推理过程
在神经网络的推理阶段,没有优化器状态和梯度,也不需要保存中间激活。模型推理阶段占用的显存要远小于训练阶段。
如果使用 float16 来进行推理,推理阶段模型参数占用的显存大概是 2 Φ 2\mathbf\Phi 2Φ。
模型参数
符号说明:
数学符号 | 定义 |
---|---|
l | 模型层数 |
d | 隐层维度 |
h | 注意力头数 |
b | batch size |
s | 序列长度 |
V | 词表大小 |
μ | 向量的均值 |
σ | 向量的方差 |
从输入到输出的顺序依次计算:
-
Embedding 层:词嵌入矩阵即一个 V → d V \rightarrow d V→d 无偏置线性层,将 V V V 大小的 one-hot 编码映射成 d d d 大小的 token。参数个数 $ Vd $。
- Positional Embedding:如果采用可训练式的位置编码,会有一些可训练模型参数,数量比较少。如果采用相对位置编码,例如 RoPE 和 ALiBi,则不包含可训练的模型参数。我们忽略这部分参数。。
-
l l l 个 block:
-
Self-attention:attention 层中有四个 d → d d \rightarrow d d→d 线性层,包含了权重: W q W_q Wq、 W k W_k Wk、 W v W_v Wv、 W o u t W_{out} Wout 以及各自的偏置。
- 权重矩阵 n 的形状 [ d , d ] [d,d] [d,d],参数个数 d 2 d^2 d2,
- 偏置形状 [ d ] [d] [d],参数个数 d。
- 总计参数量 4 d 2 + 4 d 4d^2+4d 4d2+4d.
-
Layer Normalization:设层输入是 x i n x_{in} xin,
-
layer normalization 公式: x o u t = γ ⊙ α + β x_{out}= \gamma \odot \alpha+\beta xout=γ⊙α+β , α = x i n − μ ( σ 2 + ϵ ) \alpha=\frac{x_{in}−\mu}{\sqrt{(\sigma^2+\epsilon)}} α=(σ2+ϵ)xin−μ。
-
其中 μ \mu μ 表示 x i n x_{in} xin 的均值,$ \sigma$ 表示 x i n x_{in} xin 的方差, ϵ \epsilon ϵ 防止除零, γ \gamma γ 和 β \beta β 是可学习的参数,形状都是 [ d ] [d] [d],参数个数 d d d,一层的参数个数 2 d 2d 2d。
-
因为 self-attention 和 mlp 后各有一层 layer nromalization。所以总参数个数 4 d 4d 4d。
-
-
mlp:共有两个带偏置的线性层,隐层维度默认为 4 d 4d 4d:
- 第一个是 d → 4 d d \rightarrow 4d d→4d ,权重矩阵形状 [ d , 4 d ] [d,4d] [d,4d],偏置形状 [ 4 d ] [4d] [4d],层参数 4 d 2 + 4 d 4d^2+4d 4d2+4d;
- 第二个是 4 d → d 4d \rightarrow d 4d→d ,权重矩阵形状 [ 4 d , d ] [4d,d] [4d,d],偏置形状 [ d ] [d] [d],层参数 4 d 2 + d 4d^2+d 4d2+d;
- mlp 的总参数个数 8 d 2 + 5 d 8d^2+5d 8d2+5d
-
每个 block 的参数个数共计 12 d 2 + 13 d 12d^2+13d 12d2+13d.
-
输出层和 Embedding 层共用参数。
因此,模型共计参数 l ∗ ( 12 d 2 + 13 d ) + V d l∗(12d^2+13d)+Vd l∗(12d2+13d)+Vd
CodeGen 350M 参数
Name Size Embedding transformer.wte.weight torch.Size([51200, 1024]) transformer.h.0.ln_1.weight torch.Size([1024]) transformer.h.0.ln_1.bias torch.Size([1024]) Self-attention transformer.h.0.attn.qkv_proj.weight torch.Size([3072, 1024]) Self-attention-out transformer.h.0.attn.out_proj.weight torch.Size([1024, 1024]) mlp transformer.h.0.mlp.fc_in.weight torch.Size([4096, 1024]) transformer.h.0.mlp.fc_in.bias torch.Size([4096]) transformer.h.0.mlp.fc_out.weight torch.Size([1024, 4096]) transformer.h.0.mlp.fc_out.bias torch.Size([1024])
不同版本 LLaMA 模型的参数量
实际参数量 | 隐藏维度 h | 层数 l | 12 l h 2 12lh^2 12lh2 |
---|---|---|---|
6.7B | 4096 | 32 | 6,442,450,944 |
13.0B | 5120 | 40 | 12,582,912,000 |
32.5B | 6656 | 60 | 31,897,681,920 |
65.2B | 8192 | 80 | 64,424,509,440 |
优化器状态
在训练过程中,模型的每个参数会记录梯度用于更新,此外优化器也会额外记录一些数据,称为 优化器状态。
设模型参数为 $ \mathbf\Phi$, 那么梯度的元素数量为
Φ
\mathbf\Phi
Φ ,模型参数(fp16)、模型梯度(fp16)和优化器状态(fp32), 总占用:
2
Φ
+
2
Φ
+
K
Φ
=
(
4
+
K
)
Φ
2\mathbf\Phi +2\mathbf\Phi+K\mathbf\Phi = (4+K)\mathbf\Phi
2Φ+2Φ+KΦ=(4+K)Φ
- 总占用和参数量有关,和输入大小无关。
- 在整个训练过程中都要存在显存中。 模型参数一般只能通过并行切分(Tensor Parallelism/Pipeline Parallism)能减少。优化器状态一般通过 ZeRO 来减少。
- 不同优化器的 K 值不同,算法的中间变量、框架的实现都有可能有一定区别。
AdamW 优化器 对模型中的每个参数记录了两个动量(一阶和二阶动量) m t m_t mt 和 v t v_t vt。
- 在 混合精度训练 中,会使用 float16 的模型参数 进行前向传递和后向传递,计算得到 float16 的梯度;
- 在 优化器 更新模型参数时,会使用 float32 的优化器状态、float32 的梯度、float32 的模型参数 来更新模型参数。
- 使用 AdamW 优化器 和 混合精度训练 来训练参数量为 Φ \mathbf\Phi Φ 的大模型,模型参数、梯度和优化器状态占用的显存大小为 $ 20\mathbf\Phi$ bytes
2 + 4 ⏟ weights + 2 + 4 ⏟ gradients + 4 + 4 ⏟ Adam states = 20 \underbrace{2+4}_{\text {weights}} +\underbrace{2+4}_{\text {gradients}} + \underbrace{4+4}_{\text {Adam states}} = 20 weights 2+4+gradients 2+4+Adam states 4+4=20【注】:有的参考资料中,没有考虑 fp32 的梯度,计算得到总显存为 2 Φ + 2 Φ + 12 Φ = 16 Φ 2\mathbf\Phi +2\mathbf\Phi+12\mathbf\Phi = 16\mathbf\Phi 2Φ+2Φ+12Φ=16Φ,此处参考 Using DeepSpeed and Megatron to Train Megatron-Turing NLG 530B, A Large-Scale Generative Language Model
中间激活值
激活(activations) 指的是前向传递过程中计算得到的,并在后向传递过程中需要用到的所有张量
中间激活值占用显存 分两个部分分析:Self-Attention 和 MLP,Embedding 没有中间值。
-
Self-Attention 块的中间激活占用显存大小为 11 b s d + 5 b s 2 h 11bsd+5bs^2h 11bsd+5bs2h
-
对于 MLP 块,需要保存的中间激活值为 19 b s d 19bsd 19bsd 。
-
layer norm 需要保存其输入,大小为 2 b s d 2bsd 2bsd,2 个 layer norm 需要保存的中间激活为 $ 4bsd $
-
对于 l l l 层 transformer 模型, 最终合计 l ∗ ( 34 b s d + 5 b s 2 h ) l*(34bsd +5bs^2h) l∗(34bsd+5bs2h)。
- 激活值 与输入数据的大小(批次大小 b 和 序列长度 )成正相关。
- 在训练过程中是变化值,特别是 batch size 大的时候成倍增长很容易导致 OOM。
- 可以通过 重计算、并行切分 策略减少。
在一次训练迭代中
- 模型参数(或梯度)占用的显存大小 只与 模型参数量 和 参数数据类型 有关,与输入数据的大小是没有关系的。
- 优化器状态占用的显存大小 与 优化器类型 有关,与 模型参数量 有关,与输入数据的大小无关。
- 中间激活值 与输入数据的大小(批次大小 b b b 和 序列长度 s s s)是成正相关的,随着 批次大小 b b b 和 序列长度 s s s 的增大,中间激活占用的显存会同步增大。当我们训练神经网络遇到显存不足 OOM(Out Of Memory)问题时,通常会尝试减小批次大小来避免显存不足的问题,这种方式减少的其实是中间激活占用的显存,而不是模型参数、梯度和优化器的显存。
以 GPT3-175B 为例,直观对比模型参数与中间激活的显存大小。GPT3 的模型配置如下。假设采用混合精度训练,模型参数和中间激活都采用 float16 数据类型,每个元素占 2 个 bytes。
模型名 | 参数量 | 层数 | 隐藏维度 | 注意力头数 |
---|---|---|---|---|
GPT3 | 175B | 96 | 12288 | 96 |
-
GPT3 的模型参数量为 175B,占用的显存大小为 2 ∗ 175 ∗ 1 0 9 bytes = 350 GB 2*175*10^9 \text{bytes}=350 \text{GB} 2∗175∗109bytes=350GB 。GPT3 模型需要占用 350GB 的显存。
-
GPT3 的序列长度 l l l 为 2048 。对比不同的批次大小 b b b 占用的中间激活:
-
当 l l l = 1 时,中间激活占用显存为 ( 34 b s d + 5 b s 2 h ) ∗ l = 275 , 414 , 777 , 856 bytes ≈ 275 GB (34bsd+5bs^2h)∗l=275,414,777,856 \text{bytes}\approx 275 \text{GB} (34bsd+5bs2h)∗l=275,414,777,856bytes≈275GB,大约是模型参数显存的 0.79 倍。
-
当 l l l = 64 时,中间激活占用显存为 ( 34 b s d + 5 b s 2 h ) ∗ l = 17626545782 bytes ≈ 17.6 TB (34bsd+5bs^2h)∗l=17626545782 \text{bytes}\approx 17.6 \text{TB} (34bsd+5bs2h)∗l=17626545782bytes≈17.6TB,大约是模型参数显存的 50 倍。
-
当 l l l = 128 时,中间激活占用显存为, $ (34bsd+5bs^2h)∗l=35253091565568 \text{bytes}\approx 35.3 \text{TB}$ 大约是模型参数显存的 101 倍。
可以看到随着批次大小 b b b 的增大,中间激活占用的显存远远超过了模型参数显存。通常会采用 激活重计算 技术来减少中间激活,理论上可以将中间激活显存从 O ( n ) O(n) O(n) 减少到 O ( n ) O(\sqrt{n}) O(n) ,代价是增加了一次额外前向计算的时间,本质上是“时间换空间”。
-