总结自视频(吴恩达大模型入门课):13_02_computational-challenges-of-training-llms_哔哩哔哩_bilibili
1. 模型加载阶段的显存占用
模型参数本身占用的显存应该是最主要的部分。
1.1 模型参数的显存计算
显存占用 ≈ 参数量 × 每个参数占用的字节数。简单来记,如果数据类型是float32,则1b就是占用1G显存。
-
参数量:通常以
B
(十亿)为单位,例如LLaMA-7B模型参量为70亿。 -
数据类型与字节数:
-
float32
: 4字节/参数(默认精度,如PyTorch加载时) -
float16
/bfloat16
: 2字节/参数(混合精度训练或量化加载) -
int8
: 1字节/参数(量化模型,如BitsAndBytes库)
-
示例:
-
加载LLaMA-7B模型:
-
float32
:7B × 4字节 = 28 GB -
float16
:7B × 2字节 = 14 GB -
int8
:7B × 1字节 = 7 GB
-
1.2 实际显存占用的额外因素
训练时,如果都是float32数据类型,每个模型参数会额外增加20字节内存,训练1个参数对应24字节内存。
1.3 快速估算工具
-
使用Hugging Face的
model.num_parameters()
获取参数量。 -
实际加载后,通过
torch.cuda.memory_allocated()
查看精确显存占用。