在深度学习中,模型的显存(GPU memory)占用是决定训练和推理效率的关键因素之一。特别是近年来,随着模型参数规模的不断增长,显存的合理分配和优化变得至关重要。本文将深入讲解模型在显存上的主要占用来源,提供详细的计算公式和具体例子,帮助读者理解显存需求的不同来源及如何进行优化。
1. 显存占用的主要来源
在模型训练或推理中,显存占用主要分为以下几部分:
- 模型参数(Model Parameters)
- 前向和反向激活(Activations)
- 梯度(Gradients)
- 优化器状态(Optimizer States)
依次讲解每一部分的显存计算方式及其对整体显存消耗的影响。
1.1 模型参数(Model Parameters)
模型参数是最基础的显存占用部分。通常,一个深度学习模型的参数存储在 GPU 显存中,用于参与前向计算。假设模型有 N
个参数,每个参数的类型是 dtype
(通常为 float32
、float16
或 bfloat16
)。
模型参数显存计算公式:
Parameter Memory = N × Size_per_parameter \text{Parameter Memory} = N \times \text{Size\_per\_parameter} Parameter Memory=N×Size_per_parameter
其中,Size_per_parameter
是每个参数的大小(字节)。常见数据类型对应的大小如下:
float32
: 4 字节float16
: 2 字节bfloat16
: 2 字节
实例:
假设我们有一个模型,总参数量为 10 亿(即 N = 1,000,000,000
),使用 float32
数据类型:
Parameter Memory = 1 , 000 , 000 , 000 × 4 = 4 , 000 , 000 , 000 bytes = 4 , 000 MB = 4 GB \text{Parameter Memory} = 1,000,000,000 \times 4 = 4,000,000,000 \ \text{bytes} = 4,000 \ \text{MB} = 4 \ \text{GB} Parameter Memory=1,000,000,000×4=4,000,000,000 bytes=4,000 MB=4 GB
也就是说,仅模型参数就需要 4GB 的显存。
1.2 前向和反向激活(Activations)
激活(activations)是模型前向计算时的中间输出。这些激活在反向传播时会被用于计算梯度,因此通常在训练过程中,它们需要被暂时存储在显存中。激活显存占用的大小取决于每一层的输出张量(Tensor)大小和数据类型。
对于某一层的输出张量,其激活显存大小可以表示为:
Activation Memory = Batch Size × Output Feature Maps × Height × Width × Size_per_activation \text{Activation Memory} = \text{Batch Size} \times \text{Output Feature Maps} \times \text{Height} \times \text{Width} \times \text{Size\_per\_activation} Activation Memory=Batch Size×Output Feature Maps×Height×Width×Size_per_activation
其中:
Batch Size
:批次大小。Output Feature Maps
:每一层输出的特征图(通道)数量。Height
和Width
:特征图的高和宽。Size_per_activation
:每个激活元素的大小(通常与模型参数相同,如float32
为 4 字节)。
实例:
考虑一个卷积神经网络的卷积层,输入的张量尺寸为 [Batch Size=32, Channels=64, Height=128, Width=128]
,卷积层输出的特征图数量为 128 个,激活数据类型为 float32
:
Activation Memory = 32 × 128 × 128 × 128 × 4 = 2 , 097 , 152 , 000 bytes = 2 , 097 MB ≈ 2 GB \text{Activation Memory} = 32 \times 128 \times 128 \times 128 \times 4 = 2,097,152,000 \ \text{bytes} = 2,097 \ \text{MB} \approx 2 \ \text{GB} Activation Memory=32×128×128×128×4=2,097,152,000 bytes=2,097 MB≈2 GB
所以,仅这一层的激活就占用了大约 2GB 的显存。
1.3 梯度(Gradients)
梯度用于更新模型参数,因此它们也是模型显存占用的一部分。通常,梯度的大小与模型参数的大小相同。如果模型有 N
个参数,并且每个参数的大小是 Size_per_parameter
,那么梯度显存占用可以表示为:
Gradient Memory = N × Size_per_parameter \text{Gradient Memory} = N \times \text{Size\_per\_parameter} Gradient Memory=N×Size_per_parameter
实例:
对于一个拥有 10 亿参数的模型,如果梯度的数据类型是 float32
:
Gradient Memory = 1 , 000 , 000 , 000 × 4 = 4 , 000 MB = 4 GB \text{Gradient Memory} = 1,000,000,000 \times 4 = 4,000 \ \text{MB} = 4 \ \text{GB} Gradient Memory=1,000,000,000×4=4,000 MB=4 GB
1.4 优化器状态(Optimizer States)
在训练过程中,优化器(例如 Adam、SGD)通常需要维护额外的状态,例如动量、方差估计等。这些状态的显存占用通常与模型参数大小成正比。对于 Adam 优化器,除了保存模型参数本身,还需要保存一阶动量和二阶动量,因此其显存占用可以表示为:
Optimizer Memory = N × 2 × Size_per_parameter \text{Optimizer Memory} = N \times 2 \times \text{Size\_per\_parameter} Optimizer Memory=N×2×Size_per_parameter
实例:
假设我们仍然使用一个拥有 10 亿参数的模型,使用 float32
类型:
Optimizer Memory = 1 , 000 , 000 , 000 × 2 × 4 = 8 , 000 MB = 8 GB \text{Optimizer Memory} = 1,000,000,000 \times 2 \times 4 = 8,000 \ \text{MB} = 8 \ \text{GB} Optimizer Memory=1,000,000,000×2×4=8,000 MB=8 GB
1.5 总显存占用
综上,模型训练时的总显存占用可以表示为:
Total Memory = Parameter Memory + Activation Memory + Gradient Memory + Optimizer Memory \text{Total Memory} = \text{Parameter Memory} + \text{Activation Memory} + \text{Gradient Memory} + \text{Optimizer Memory} Total Memory=Parameter Memory+Activation Memory+Gradient Memory+Optimizer Memory
将上面的实例代入:
- 模型参数显存:4GB
- 激活显存:2GB(仅考虑一层的激活)
- 梯度显存:4GB
- 优化器显存:8GB
总显存占用为:
Total Memory = 4 + 2 + 4 + 8 = 18 GB \text{Total Memory} = 4 + 2 + 4 + 8 = 18 \ \text{GB} Total Memory=4+2+4+8=18 GB
注意:如果模型有多层激活张量,实际激活显存可能更大。
2. 显存优化策略
当模型显存占用过高时,可以考虑以下几种优化策略:
2.1 Mixed Precision Training(混合精度训练)
通过将模型参数和计算转换为 float16
,可以有效减小显存占用。与 float32
相比,float16
的显存占用直接减半。
优化后显存:
将上例中的所有 float32
替换为 float16
,总显存占用变为:
Total Memory = 4 × 0.5 + 2 × 0.5 + 4 × 0.5 + 8 × 0.5 = 9 GB \text{Total Memory} = 4 \times 0.5 + 2 \times 0.5 + 4 \times 0.5 + 8 \times 0.5 = 9 \ \text{GB} Total Memory=4×0.5+2×0.5+4×0.5+8×0.5=9 GB
2.2 Gradient Checkpointing(梯度检查点)
通过在前向传播时保存部分激活张量,而非全部保存,可以有效减小激活显存的开销。这样,在反向传播时,模型会重新计算部分激活值,但整体显存占用显著下降。
2.3 模型并行(Model Parallelism)
对于超大模型(如 GPT-3),可以将模型划分到不同的 GPU 上,从而分散显存负担。
2.4 ZeRO 优化
ZeRO 是一种高效的显存优化技术,可以将优化器状态、梯度和参数在多个 GPU 之间进行分割,从而极大地降低显存开销。