显存计算


显存

显存占用分析

  • Model States
  • 模型参数
  • 后向传递计算得到的 梯度
  • 优化器状态
  • Activation
  • 前向计算过程中产生的 中间激活

数据类型

  • float32(FP32):32 位浮点数,也称为单精度。
  • float16(FP16):16 位浮点数,表示范围较小,也被称为半精度。
  • bfloat16(BF16):扩大了指数位数,缩小了小数位数,因此表示的范围更大,精度更弱。

一般采用 16 位的表示,那么一个参数占用 2byte,即 2B。

FP16 的精度高,但是表示范围小,容易上溢;

BF16 的表示范围大,但精度低,因此更容易下溢,为了避免溢出问题,提出了混合精度方案。

训练过程

训练大模型时通常会采用 AdamW 优化器,并用 混合精度 训练来加速训练,基于这个前提分析显存占用。

在一次训练迭代中,每个可训练模型参数都会对应 1 个梯度,并对应 2 个优化器状态(Adam 优化器梯度的一阶动量和二阶动量)。

推理过程

在神经网络的推理阶段,没有优化器状态和梯度,也不需要保存中间激活。模型推理阶段占用的显存要远小于训练阶段

如果使用 float16 来进行推理,推理阶段模型参数占用的显存大概是 2 Φ 2\mathbf\Phi 2Φ

模型参数

符号说明:

数学符号定义
l模型层数
d隐层维度
h注意力头数
bbatch size
s序列长度
V词表大小
μ向量的均值
σ向量的方差

从输入到输出的顺序依次计算:

  • Embedding 层:词嵌入矩阵即一个 V → d V \rightarrow d Vd 无偏置线性层,将 V V V 大小的 one-hot 编码映射成 d d d 大小的 token。参数个数 $ Vd $。

    • Positional Embedding:如果采用可训练式的位置编码,会有一些可训练模型参数,数量比较少。如果采用相对位置编码,例如 RoPE 和 ALiBi,则不包含可训练的模型参数。我们忽略这部分参数。。
  • l l l 个 block:

  • Self-attention:attention 层中有四个 d → d d \rightarrow d dd 线性层,包含了权重: W q W_q Wq W k W_k Wk W v W_v Wv W o u t W_{out} Wout 以及各自的偏置。

    • 权重矩阵 n 的形状 [ d , d ] [d,d] [d,d],参数个数 d 2 d^2 d2
    • 偏置形状 [ d ] [d] [d],参数个数 d。
    • 总计参数量 4 d 2 + 4 d 4d^2+4d 4d2+4d.
  • Layer Normalization:设层输入是 x i n x_{in} xin

    • layer normalization 公式: x o u t = γ ⊙ α + β x_{out}= \gamma \odot \alpha+\beta xout=γα+β , α = x i n − μ ( σ 2 + ϵ ) \alpha=\frac{x_{in}−\mu}{\sqrt{(\sigma^2+\epsilon)}} α=(σ2+ϵ) xinμ

    • 其中 μ \mu μ 表示 x i n x_{in} xin 的均值,$ \sigma$ 表示 x i n x_{in} xin 的方差, ϵ \epsilon ϵ 防止除零, γ \gamma γ β \beta β 是可学习的参数,形状都是 [ d ] [d] [d],参数个数 d d d,一层的参数个数 2 d 2d 2d

    • 因为 self-attention 和 mlp 后各有一层 layer nromalization。所以总参数个数 4 d 4d 4d

  • mlp:共有两个带偏置的线性层,隐层维度默认为 4 d 4d 4d

    • 第一个是 d → 4 d d \rightarrow 4d d4d ,权重矩阵形状 [ d , 4 d ] [d,4d] [d,4d],偏置形状 [ 4 d ] [4d] [4d],层参数 4 d 2 + 4 d 4d^2+4d 4d2+4d
    • 第二个是 4 d → d 4d \rightarrow d 4dd ,权重矩阵形状 [ 4 d , d ] [4d,d] [4d,d],偏置形状 [ d ] [d] [d],层参数 4 d 2 + d 4d^2+d 4d2+d
    • mlp 的总参数个数 8 d 2 + 5 d 8d^2+5d 8d2+5d
  • 每个 block 的参数个数共计 12 d 2 + 13 d 12d^2+13d 12d2+13d.

  • 输出层和 Embedding 层共用参数。

因此,模型共计参数 l ∗ ( 12 d 2 + 13 d ) + V d l∗(12d^2+13d)+Vd l(12d2+13d)+Vd

CodeGen 350M 参数

NameSize
Embeddingtransformer.wte.weighttorch.Size([51200, 1024])
transformer.h.0.ln_1.weighttorch.Size([1024])
transformer.h.0.ln_1.biastorch.Size([1024])
Self-attentiontransformer.h.0.attn.qkv_proj.weighttorch.Size([3072, 1024])
Self-attention-outtransformer.h.0.attn.out_proj.weighttorch.Size([1024, 1024])
mlptransformer.h.0.mlp.fc_in.weighttorch.Size([4096, 1024])
transformer.h.0.mlp.fc_in.biastorch.Size([4096])
transformer.h.0.mlp.fc_out.weighttorch.Size([1024, 4096])
transformer.h.0.mlp.fc_out.biastorch.Size([1024])

不同版本 LLaMA 模型的参数量

实际参数量隐藏维度 h层数 l 12 l h 2 12lh^2 12lh2
6.7B4096326,442,450,944
13.0B51204012,582,912,000
32.5B66566031,897,681,920
65.2B81928064,424,509,440
优化器状态

在训练过程中,模型的每个参数会记录梯度用于更新,此外优化器也会额外记录一些数据,称为 优化器状态

设模型参数为 $ \mathbf\Phi$, 那么梯度的元素数量为 Φ \mathbf\Phi Φ ,模型参数(fp16)、模型梯度(fp16)和优化器状态(fp32), 总占用
2 Φ + 2 Φ + K Φ = ( 4 + K ) Φ 2\mathbf\Phi +2\mathbf\Phi+K\mathbf\Phi = (4+K)\mathbf\Phi 2Φ+2Φ+KΦ=(4+K)Φ

  1. 总占用和参数量有关,和输入大小无关。
  2. 在整个训练过程中都要存在显存中。 模型参数一般只能通过并行切分(Tensor Parallelism/Pipeline Parallism)能减少。优化器状态一般通过 ZeRO 来减少。
  3. 不同优化器的 K 值不同,算法的中间变量、框架的实现都有可能有一定区别。

AdamW 优化器 对模型中的每个参数记录了两个动量(一阶和二阶动量) m t m_t mt v t v_t vt

  • 混合精度训练 中,会使用 float16 的模型参数 进行前向传递和后向传递,计算得到 float16 的梯度
  • 优化器 更新模型参数时,会使用 float32 的优化器状态float32 的梯度float32 的模型参数 来更新模型参数。
  • 使用 AdamW 优化器混合精度训练 来训练参数量为 Φ \mathbf\Phi Φ 的大模型,模型参数、梯度和优化器状态占用的显存大小为 $ 20\mathbf\Phi$ bytes
    2 + 4 ⏟ weights + 2 + 4 ⏟ gradients + 4 + 4 ⏟ Adam states = 20 \underbrace{2+4}_{\text {weights}} +\underbrace{2+4}_{\text {gradients}} + \underbrace{4+4}_{\text {Adam states}} = 20 weights 2+4+gradients 2+4+Adam states 4+4=20

【注】:有的参考资料中,没有考虑 fp32 的梯度,计算得到总显存为 2 Φ + 2 Φ + 12 Φ = 16 Φ 2\mathbf\Phi +2\mathbf\Phi+12\mathbf\Phi = 16\mathbf\Phi 2Φ+2Φ+12Φ=16Φ,此处参考 Using DeepSpeed and Megatron to Train Megatron-Turing NLG 530B, A Large-Scale Generative Language Model

中间激活值

激活(activations) 指的是前向传递过程中计算得到的,并在后向传递过程中需要用到的所有张量

中间激活值占用显存 分两个部分分析:Self-Attention 和 MLP,Embedding 没有中间值。

  • Self-Attention 块的中间激活占用显存大小为 11 b s d + 5 b s 2 h 11bsd+5bs^2h 11bsd+5bs2h

  • 对于 MLP 块,需要保存的中间激活值为 19 b s d 19bsd 19bsd

  • layer norm 需要保存其输入,大小为 2 b s d 2bsd 2bsd,2 个 layer norm 需要保存的中间激活为 $ 4bsd $

  • 对于 l l l 层 transformer 模型, 最终合计 l ∗ ( 34 b s d + 5 b s 2 h ) l*(34bsd +5bs^2h) l(34bsd+5bs2h)

  1. 激活值 与输入数据的大小(批次大小 b序列长度 )成正相关。
  2. 在训练过程中是变化值,特别是 batch size 大的时候成倍增长很容易导致 OOM。
  3. 可以通过 重计算并行切分 策略减少。

在一次训练迭代中

  • 模型参数(或梯度)占用的显存大小 只与 模型参数量参数数据类型 有关,与输入数据的大小是没有关系的。
  • 优化器状态占用的显存大小优化器类型 有关,与 模型参数量 有关,与输入数据的大小无关。
  • 中间激活值 与输入数据的大小(批次大小 b b b序列长度 s s s)是成正相关的,随着 批次大小 b b b序列长度 s s s 的增大,中间激活占用的显存会同步增大。当我们训练神经网络遇到显存不足 OOM(Out Of Memory)问题时,通常会尝试减小批次大小来避免显存不足的问题,这种方式减少的其实是中间激活占用的显存,而不是模型参数、梯度和优化器的显存。

以 GPT3-175B 为例,直观对比模型参数与中间激活的显存大小。GPT3 的模型配置如下。假设采用混合精度训练,模型参数和中间激活都采用 float16 数据类型,每个元素占 2 个 bytes。

模型名参数量层数隐藏维度注意力头数
GPT3175B961228896
  • GPT3 的模型参数量为 175B,占用的显存大小为 2 ∗ 175 ∗ 1 0 9 bytes = 350 GB 2*175*10^9 \text{bytes}=350 \text{GB} 2175109bytes=350GB 。GPT3 模型需要占用 350GB 的显存。

  • GPT3 的序列长度 l l l 为 2048 。对比不同的批次大小 b b b 占用的中间激活:

    • l l l = 1 时,中间激活占用显存为 ( 34 b s d + 5 b s 2 h ) ∗ l = 275 , 414 , 777 , 856 bytes ≈ 275 GB (34bsd+5bs^2h)∗l=275,414,777,856 \text{bytes}\approx 275 \text{GB} (34bsd+5bs2h)l=275,414,777,856bytes275GB,大约是模型参数显存的 0.79 倍。

    • l l l = 64 时,中间激活占用显存为 ( 34 b s d + 5 b s 2 h ) ∗ l = 17626545782 bytes ≈ 17.6 TB (34bsd+5bs^2h)∗l=17626545782 \text{bytes}\approx 17.6 \text{TB} (34bsd+5bs2h)l=17626545782bytes17.6TB,大约是模型参数显存的 50 倍。

    • l l l = 128 时,中间激活占用显存为, $ (34bsd+5bs^2h)∗l=35253091565568 \text{bytes}\approx 35.3 \text{TB}$ 大约是模型参数显存的 101 倍。

    可以看到随着批次大小 b b b 的增大,中间激活占用的显存远远超过了模型参数显存。通常会采用 激活重计算 技术来减少中间激活,理论上可以将中间激活显存从 O ( n ) O(n) O(n) 减少到 O ( n ) O(\sqrt{n}) O(n ) ,代价是增加了一次额外前向计算的时间,本质上是“时间换空间”。

  • 38
    点赞
  • 34
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值