训练过程显存占用分为4部分,模型参数、参数梯度、优化器、
中间变量。
大模型训练llama13B为例的float16为例分析:
模型参数(fp16): fp16占2个byte,是参数量的2倍。2*13GB
参数梯度(fp16):fp16占2个byte,是参数量的2倍。2*13GB
优化器(fp32):fp32占4个byte,是参数量的4倍。4*13GB,但是优化器除了存储权重w还存储其他值,比如adamw还额外存储了动力和方差,所以adamw占用 3 * (4*13B)
中间变量(fp16): 主要由attention和MLP计算得到的中间变量。与输入的batch、序列长度、模型层数 相关。而且这一部分占用的内存更大,比上面的3个部分还大。13B的llama参数,中间变量占45GB的空间。
所以,针对占用的内存,有不同的优化方法。对于,模型参数,可以采样模型并行如megatron。中间变量采用数据并行的deepspeed的zero方法。
具体的计算方法和优化方案,参考以下链接:
【Transformer 基础系列】手推显存占用 - 知乎 (zhihu.com)
PyTorch显存机制分析 - 知乎 (zhihu.com)