激活检查点(Activation Checkpointing)是一种用于优化深度学习模型训练的技术,它可以在内存使用和计算效率之间进行权衡,以在有限的硬件资源下训练更大的模型。
在深度学习模型的训练过程中,前向传播会计算并存储每一层的激活值,这些激活值在后向传播时被用来计算梯度。然而,对于深度很大的模型,这种方式可能会导致内存溢出,因为需要存储大量的激活值。
激活检查点技术通过在前向传播过程中只存储一部分(而不是全部)的激活值来解决这个问题。对于没有存储的激活值,如果在后向传播过程中需要它们,就重新计算这些值。这种方法可以显著减少内存使用,但是会增加一些计算开销,因为需要重新计算一些激活值。
这种技术在训练大型模型(特别是在内存有限的设备上)时非常有用。例如,微软的深度速度优化库 DeepSpeed 就使用了激活检查点技术来实现在单个GPU上训练数十亿参数的模型。
这是微软DeepSpeed库的位置:
具体而言,微软的DeepSpeed在模型训练的API中使用了此项技术,代码在这个链接中:
具体实现的关键是一个名为CheckpointFunction
的类
该类继承自torch.autograd.Function
并定义了两个静态方法:forward
和backward
,分别对应于模型的前向传播和反向传播过程。
在forward
方法中,它首先保存了当前的随机数生成器(RNG)状态,然后运行前向传播函数,并将结果输出。如果启用了激活分区(Partition Activations)或CPU检查点(CPU Checkpoint),则会将激活状态保存到CPU内存中,以减少GPU内存的使用。
在backward
方法中,它首先恢复在前向传播过程中保存的RNG状态,然后运行前向传播函数以重新计算激活状态,并使用保存的梯度进行反向传播。如果启用了激活分区或CPU检查点,它会将激活状态从CPU内存移回到GPU内存,以进行反向传播计算。
此外,这个类还包含了一些优化和调试功能,如同步CUDA操作,计时器用于性能分析,以及内存使用情况的输出。