全局 batch size = 卡数 × 每卡 batch size × 梯度累积步数
是LLM训练非常关键的一个概念,尤其是在分布式训练(如使用多张 GPU 或 TPU)时,理解这个公式非常重要。下面我来逐个解释每一项的含义,并说明为什么这个公式成立。
🧠 一、为什么要有“全局 batch size”?
在训练深度学习模型时,batch size(批大小) 是影响模型训练速度、内存占用、收敛性等的重要超参数。
但在使用多卡训练(如 PyTorch 的 DDP)或梯度累积时,单张卡上的 batch size 不再是完整的 batch size,所以引入 “全局 batch size” 来表示一次参数更新所使用的总样本数。
🧮 二、公式详解
全局 batch size = 卡数 × per device batch size × gradient accumulation steps
1. 卡数(Number of GPUs / Devices)
-
含义:你用于训练的 GPU 或计算设备的数量。
-
作用:每张卡独立计算一部分数据的梯度,最后将梯度汇总。
2. 每卡 batch size(Per device batch size)
-
含义:每张卡上每次前向/反向传播使用的样本数。
-
举例:如果你有 4 张卡,每张卡处理 8 个样本,那么一次前向/反向传播总共处理 4 × 8 = 32 个样本。
-
限制因素:显存大小。每张卡的 batch size 不能太大,否则会 OOM(显存溢出)。
3. 梯度累积步数(Gradient Accumulation Steps)
-
含义:在更新参数前,累积多少次梯度。
-
作用:当显存不足时,可以用较小的 batch size 多次反向传播后再更新参数,模拟更大的 batch size。
-
举例:如果你的device是1, per device batch size 是 8,梯度累积步数是 4,那等价于一次使用 1 × 8 × 4 = 32 个样本更新参数。
-
我们以
gradient_accumulation_steps = 4
为例:
复制代码1Step 1: 前向传播 → 反向传播 → 梯度加到缓存中(不更新参数)
Step 2: 前向传播 → 反向传播 → 梯度加到缓存中(不更新参数)
Step 3: 前向传播 → 反向传播 → 梯度加到缓存中(不更新参数)
Step 4: 前向传播 → 反向传播 → 梯度加到缓存中(不更新参数)
Step 5: 所有梯度加起来 → 更新一次参数
也就是说,4 次小 batch 的梯度加在一起,才做一次参数更新。
-
梯度累积步数越大越好吗?
梯度累积步数不是越大越好,因为虽然它能节省显存、模拟大 batch,但会带来训练变慢、收敛不稳定、优化器行为变化等问题。
🧾 三、公式直观理解
项 | 数值 | 说明 |
卡数 | 4 | 使用 4 张 GPU |
per device batch size | 8 | 每张卡每次处理 8 个样本 |
gradient accumulation steps | 2 | 每累积 2 次梯度才更新一次参数 |
那么:
全局 batch size = 4(卡数) × 8(每卡) × 2(累积步数) = 64
也就是说,每更新一次模型参数,总共使用了 64 个样本的信息。
📌 四、为什么这个公式重要?
1. 影响训练效果
-
更大的全局 batch size 通常可以使用更大的学习率。
-
batch size 太小可能导致训练不稳定,太大可能导致泛化能力下降。
2. 影响训练速度和资源利用
-
每卡 batch size 太小 → 显卡利用率低。
-
梯度累积步数太多 → 占用更多内存,训练变慢。
3. 影响学习率缩放策略
-
当你增加全局 batch size 时,通常需要按比例调整学习率(如线性缩放规则)。
🧪 五、举个例子(对比不同设置)
卡数 | per device batch size | gradient acc steps | 全局 batch size |
1 | 32 | 1 | 32 |
4 | 8 | 1 | 32 |
2 | 8 | 2 | 32 |
1 | 8 | 4 | 32 |
这四种设置的全局 batch size 都是 32,但显存占用和训练速度不同。
📝 六、如何设置这些参数?
1. 先确定每卡 batch size
-
根据显存限制,找到单卡最大能跑的 batch size。
2. 再根据硬件资源决定卡数
-
如果你有 4 张卡,可以跑 4 × per_device_batch_size。
3. 最后决定是否使用梯度累积
-
如果显存不够,可以降低 per_device_batch_size,增加 gradient accumulation steps。
✅ 七、总结
参数 | 含义 | 公式中的作用 |
卡数 | GPU 数量 | 每卡处理一部分 batch |
per device batch size | 每卡每次处理的样本数 | 决定单卡负载 |
gradient accumulation steps | 多少次反向传播后才更新参数 | 累积梯度,模拟更大的 batch size |
全局 batch size | 实际用于更新模型的总样本数 |
|