激活检查点(Activation Checkpointing)是一种用于优化深度学习模型训练的技术,它可以在内存使用和计算效率之间进行权衡,以在有限的硬件资源下训练更大的模型。
在深度学习模型的训练过程中,前向传播会计算并存储每一层的激活值,这些激活值在后向传播时被用来计算梯度。然而,对于深度很大的模型,这种方式可能会导致内存溢出,因为需要存储大量的激活值。
激活检查点技术通过在前向传播过程中只存储一部分(而不是全部)的激活值来解决这个问题。对于没有存储的激活值,如果在后向传播过程中需要它们,就重新计算这些值。这种方法可以显著减少内存使用,但是会增加一些计算开销,因为需要重新计算一些激活值。
这种技术在训练大型模型(特别是在内存有限的设备上)时非常有用。例如,微软的深度速度优化库 DeepSpeed 就使用了激活检查点技术来实现在单个GPU上训练数十亿参数的模型。
这是微软DeepSpeed库的位置:
具体而言,微软的DeepSpeed在模型训练的AP