介绍
梯度检查点(Gradient Checkpointing)是一种深度学习优化技术,它的目的是减少在神经网络训练过程中的内存占用。在训练深度学习模型时,我们需要存储每一层的激活值(即网络层的输出),这样在反向传播时才能计算梯度。但是,如果网络层数非常多,这些激活值会占用大量的内存。
梯度检查点技术通过只在前向传播时保存部分激活值的信息,而在反向传播时重新计算其他激活值,从而减少了内存的使用。具体来说,它在前向传播时使用 torch.no_grad() 来告诉PyTorch不需要计算梯度,因为这些激活值会在反向传播时重新计算。
比喻
假设你在做一道复杂的数学题,通常你需要写下每一步的计算结果,以便在检查错误时可以追溯回去。但如果你确信大部分计算都是正确的,只是在最后几步可能出错,那么你就可以只保存最后几步的结果,然后在检查时重新计算前面的步骤。这样,你就可以节省纸张(在神经网络中就是内存)。
例子
假设有一个深度神经网络,用于图像分类任务。网络有10层,每层都需要保存激活值以便反向传播时计算梯度。如果没有使用梯度检查点,你需要在内存中保存所有这些激活值。
现在,使用梯度检查点,你可以在前向传播时只保存第1层和第10层的激活值,而在反向传播时重新计算第2层到第9层的激活值。这样,你就大大减少了需要保存的激活值数量,从而节省了内存。
import torch
from torch.utils.checkpoint import checkpoint
def forward_model(model, input, checkpointing=True)<