一句话总结
xB的大模型,训练的显存占用约为16x GB, 使用lora大概占用4xGB。(默认全精度float32存储)
推理的显存占用约为4xGB
显存占用分析
主要来源
模型参数、梯度、优化器、激活值
模型参数
与类型有关,fp32, 4个字节,fp16 2个字节,int8 1个字节
梯度
与参数量保持一致。
优化器
以adam为例,需要计算上一时刻的m与v, 所以是参数量的2倍。
激活值
y=wx+b, y就是激活值,与模型结构相关,不同模型结构激活值数量不同。
对于 l层transformer模型,中间激活占用的显存大小可以近似为
与batch以及输入序列s成正比。通常会采用激活重计算技术来减少中间激活,代价是增加了一次额外前向计算的时间,本质上是“时间换空间”。
训练
所以我们最简单的估算,忽略激活值。以7B模型fp32为例子,7Bx4x(1+1+2)=7Bx16的显存
可以使用lora、减少梯度和优化器,大幅降低显存
可以使用deepspeed优化显存
batch增大显存增加主要原因是增加了激活值的显存!!!!
推理
推理只有参数与激活占用。以7B模型fp32为例子,7Bx4=28B的显存
可以使用fp16,int8推理降低显存占用。
参考:
分析transformer模型的参数量、计算量、中间激活、KV cache - 回旋托马斯x的文章 - 知乎
https://zhuanlan.zhihu.com/p/624740065
大模型训练显存估算 - phynlp的文章 - 知乎
https://zhuanlan.zhihu.com/p/680434161
多大的显存能够满足主流大模型的训练开销? - 小满哥的回答 - 知乎
https://www.zhihu.com/question/636721650/answer/3386092714