本文试图以最清晰的方式手动推导 Transformers 每一步的参数量到显存、计算量问题。理解底层,才能更好的做训练和优化。可能是目前最全的大模型显存优化方案分析。
本文内容包括
(1)模型训练和推理过程中的显存占用
(2)KV cache、中间激活值等显存占用
(3)模型状态显存优化方案: Megatron(3D) + Deepspeed(ZeRO)(更新于2023-09-11)
(4)激活值显存优化方案:重计算 + 3D 并行(更新于2023-08-11)
(5)KV Cache 显存优化方案:MQA 和 GQA(更新于2023-09-11)
前置知识和标记
- 显存占用 = 参数数量 x 该参数精度占用的 bytes 数
- 换算关系:Int8 需1 bytes, fp16 / bf16 数需 2 bytes, fp32 需要 4 bytes
- transformer 模型的层数为 l l