大模型训练ZeRO内存优化原理详解


0. 引言

Zero Redundancy Optimizer (ZeRO),主要目标是减少内存使用并加速大规模模型的训练过程。它通过在多个 GPU 或者节点之间分散模型的状态(如梯度和参数)来实现这一目标。这种分散减少了每个计算节点上存储的冗余数据量,从而降低了内存占用。

论文:《ZeRO: Memory Optimizations Toward Training Trillion Parameter Models》

1. GPU 内存分布

1.1 模型状态

模型状态包括:
(1)优化器状态(Optimizer States),例如使用 Adam 优化器时的动量和梯度方差
(2)梯度(Gradients)
(3)参数( Parameters)

上面的模型状态通常占据了大部分的内存,在混合精度训练中,还需要额外的内存来存储 fp32 的参数和优化器状态。

比如 GPT-2(具有 1.5B 参数)模型,模型状态的保存要求至少 24 GB 的内存。

1.2 剩余内存

除了模型状态外,剩余的内存包含:
(1)激活内存。用于正向传播以执行反向传播的存储,可以通过激活检查点(checkpointing)来减少,但会提升计算量;
(2)临时缓冲区。用于存储中间结果,其大小随着模型大小的增加而增加
(3)不可用的碎片化内存

以上统称为除保存模型状态之外的剩余内存。

2. ZeRO 优化

2.1 ZeRO-DP 优化

ZeRO-DP(ZeRO 数据并行),优化三个阶段的内存消耗情况:
在这里插入图片描述
Ψ \Psi Ψ 为模型大小(参数个数), K K K 为优化器状态的内存乘数, N d N_d Nd 为数据并行度,可以理解为 GPU 卡数。

在本例中,假设基于 Adam 优化器的混合精度训练,模型大小为 7.5B, N d = 64 N_d=64 Nd=64 K = 12 K = 12 K=12

下面分别介绍 ZeRO-DP 优化的三个阶段的具体情况。

2.1.1 ZeRO-Stage1: 优化器状态划分

(1) P o s P_{os} Pos(Optimizer State Partitioning,优化器状态划分)
ZeRO 通过将优化器状态划分为 N d N_d Nd 个数据并行进程,每个进程仅存储、更新其对应分区的优化器状态,即整体优化器状态的 1 N d \frac1{N_d} Nd1,从而减少了每个设备上所需的内存量。在每个训练步骤结束时,再收集每一个进程的结果,以获取整体更新后的状态参数。

(2)ZeRO-Stage1 内存优化后的结果,主要针对优化器状态(请参考上图):
( 2 + 2 ) Ψ + K ∗ Ψ N d (2+2)\Psi + \frac{K*\Psi}{N_d} (2+2)Ψ+NdKΨ
可见,优化器状态内存在原始基础上有一个 N d N_d Nd 的除数。

(3)举例
在 7.5B 的模型上,标准的情况下要求 120GB 的内存,但是使用 P o s P_{os} Pos 后, N d = 64 N_d=64 Nd=64 的情况下,仅要求 31.4 GB 的内存。

而当 N d N_d Nd 非常大时,内存消耗:
( 2 + 2 ) Ψ + K ∗ Ψ N d ≈ 4 Ψ (2+2)\Psi + \frac{K*\Psi}{N_d} \approx4\Psi (2+2)Ψ+NdKΨ
与原始的比例:
4 4 + K \frac{4}{4+K} 4+K4
K = 12 K=12 K=12 时,为 1 4 \frac14 41,即内存是原始的 1 4 \frac14 41

2.1.2 ZeRO-Stage2: 优化器状态+梯度划分

(1) P g P_g Pg(Gradient Partitioning,梯度划分)
每个数据并行进程只存储和更新其对应的参数分区所需的梯度,减少了存储全部梯度的内存需求

(2) P o s + g P_{os+g} Pos+g,优化器状态+梯度划分
即在 ZeRO-Stage1 的 P o s P_{os} Pos 基础上,增加了 P g P_g Pg,则是 ZeRO-Stage2

(3)ZeRO-Stage2 内存优化后的结果,主要针对优化器状态+梯度(请参考上图):
( 2 + 2 + K ) ∗ Ψ N d \frac{(2+2+K)*\Psi}{N_d} Nd(2+2+K)Ψ

(4)举例
在 7.5B 的模型上,标准的情况下要求 120GB 的内存,但是使用 P o s + g P_{os+g} Pos+g 后, N d = 64 N_d=64 Nd=64 的情况下,仅要求 16.6 GB 的内存。

而当 N d N_d Nd 非常大时,内存消耗:
( 2 + 2 + K ) ∗ Ψ N d ≈ 0 \frac{(2+2+K)*\Psi}{N_d} \approx0 Nd(2+2+K)Ψ0
这意味着,理论情况下,当设备足够多时,可以训练任意大的

2.1.3 ZeRO-Stage3: 优化器状态+梯度+参数划分

(1) P p P_p Pp( Parameter Partitioning,梯度划分)
类似于优化器状态和梯度的划分,每个进程只存储其参数分区的参数,在需要时通过广播从其他进程接收非本分区的参数。

(2) P o s + g + p P_{os+g+p} Pos+g+p,优化器状态+梯度划分
即在 ZeRO-Stage2 的 P o s + g P_{os+g} Pos+g 基础上,增加了 P p P_p Pp,则是 ZeRO-Stage3

(3)ZeRO-Stage3 内存优化后的结果,主要针对优化器状态+梯度+参数(请参考上图):
2 Ψ + ( 2 + K ) ∗ Ψ N d 2\Psi + \frac{(2+K)*\Psi}{N_d} +Nd(2+K)Ψ

(4)举例
在 7.5B 的模型上,标准的情况下要求 120GB 的内存,但是使用 P o s + g + p P_{os+g+p} Pos+g+p 后, N d = 64 N_d=64 Nd=64 的情况下,仅要求 1.9 GB 的内存。

而当 N d N_d Nd 非常大时,内存消耗:
2 Ψ + ( 2 + K ) ∗ Ψ N d ≈ 2 Ψ 2\Psi + \frac{(2+K)*\Psi}{N_d} \approx2\Psi +Nd(2+K)Ψ
与原始的比例:
2 4 + K \frac{2}{4+K} 4+K2
K = 12 K=12 K=12 时,为 1 8 \frac18 81,即内存是原始的 1 8 \frac18 81

2.2 ZeRO-R 优化

2.2.1 减少激活内存

(1) P a P_a Pa( Partitioned Activation Checkpointing,划分激活检查点)
ZeRO-R 通过 P a P_a Pa 操作来减少因模型并行化(MP)导致的激活内存冗余。在正向传播过程中,每一层的输入激活被分割并存储在所有模型并行进程中,仅存储分区的激活检查点,而不是复制副本。ZeRO-R 使用 all-gather 操作在反向传播需要时重新生成激活的复制副本。

(2) P a + c p u P_{a+cpu} Pa+cpu
对于非常大的模型,ZeRO-R 可以将分割的激活检查点卸载到 CPU 内存中,几乎将激活内存开销降至零,但是要额外的通信成本。

(3)举例
例如,对于一个 100B 参数的模型,如果每个 Transformer 层仅检查点一个激活,那么仅存储激活检查点就需要一个 GPU 约 33GB 的内存。但是,使用 ZeRO-R 中的 P a P_a Pa 优化,可以将其降低到每 GPU 约 2GB。此外,这 2GB 可以卸载到 CPU 上,将激活的内存占用减少到几乎为零。

2.2.2 管理临时缓冲区

ZeRO-R 通过使用固定大小的缓冲区来避免临时缓冲区随着模型大小增加而膨胀,同时确保缓冲区足够大以保持效率。

2.2.3 管理碎片化内存

内存碎片化是由于短期和长期存活内存对象的交错导致的。ZeRO-R 执行即时内存碎片整理,通过将激活检查点和梯度移动到预先分配的连续内存缓冲区中,不仅增加了内存的可用性,还通过减少内存分配器寻找连续内存块的时间来提高效率。

3. ZeRO 通讯分析

3.1 ZeRO-DP通讯分析

3.1.1 P o s + g P_{os+g} Pos+g的通讯量

使用梯度分区,每个进程只存储更新其相应参数分区所需的梯度部分。
(1)ZeRO 只需要在梯度上进行分散缩减操作,从而产生 Ψ \Psi Ψ 的通信量。
(2)在每个进程更新其负责的参数分区后,执行全收集以从所有数据并行进程中收集所有更新的参数。这也会产生 Ψ \Psi Ψ 的通信量。
(3)因此,每个训练步骤的总通信量为 Ψ + Ψ = 2 Ψ \Psi + \Psi = 2\Psi Ψ+Ψ= 与标准 DP 情况完全相同。

3.1.2 P o s + g + p P_{os+g+p} Pos+g+p的通讯量

加入 P p P_p Pp 参数划分之后,ZeRO-DP 的通信量最多增加到标准 DP的1.5倍,即 3 Ψ 3\Psi 。这是因为在前向传播和反向传播中,参数需要在进程间进行广播和收集。尽管如此, P p P_p Pp 阶段进一步将内存占用减少,且减少程度与数据并行度 N d N_d Nd 成线性关系。

3.2 ZeRO-R通讯分析

P a P_a Pa 在 ZeRO-R 中的通信开销与传统的 MP 模型并行方法相比,增加量通常不到10%。

但是由于 ZeRO-R 还提供了将激活分区卸载到 CPU 内存的选项 P a + c p u P_{a+cpu} Pa+cpu,这可以在保持效率不降低太多的同时,进一步减少 GPU 上的内存需求。

4. 参考

[1] https://arxiv.org/abs/1910.02054


欢迎关注本人,我是喜欢搞事的程序猿; 一起进步,一起学习;

欢迎关注知乎/CSDN:SmallerFL

也欢迎关注我的wx公众号(精选高质量文章):一个比特定乾坤

在这里插入图片描述

  • 20
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

SmallerFL

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值