大语言模型的工程技巧(四)——梯度检查点

相关说明

这篇文章的大部分内容参考自我的新书《解构大语言模型:从线性回归到通用人工智能》,欢迎有兴趣的读者多多支持。

本文将讨论如何利用梯度检查点算法来减少模型在训练时候(更准确地说是运行反向传播算法时)的内存开支。这在训练超大规模的模型时会用到。

关于其他的工程技巧可以参考:

关于大语言模型的讨论请参考:

一、标准反向传播

根据梯度的定义,变量的梯度与其本身的值密切相关。因此,要想得到某个变量的梯度,必须先知道这个变量的值。这也是为什么在进行反向传播算法之前,需要先对计算图进行向前传播,并记录每个节点的计算结果,如图1左侧部分所示。这样在计算节点的梯度时,可以利用这些事先缓存的结果,直接启动反向传播过程,从而得到梯度,如图1中的节点d所示。这种方法也被称为标准反向传播。这种方式能够确保梯度计算以最高效的方式进行。

图1

图1

二、内存极简算法

然而,采用标准反向传播算法会造成较大的内存开销。为了在计算过程中尽可能地压缩内存使用,可以采用一种以时间换空间的方法。在这种算法中,一旦向前传播完成,仅会保留顶点的计算结果,而中间节点的结果会被清空(叶子节点的值会保留)。在反向传播遇到中间计算节点没有缓存时,则重新触发向前传播,以获取所需节点的结果。这就是内存极简的反向传播算法。以节点d为例,为了计算其梯度,需要首先从节点a开始重新触发向前传播直到节点d,并缓存计算结果。然后使用这个缓存的结果以及节点e的梯度,计算出节点d的梯度。对于其他节点,也采用类似的步骤计算梯度。通过这种方式,在完成反向传播的同时,节省了内存开销。以图1为例,内存极简算法只需要3个存储空间,而标准算法需要5个存储空间。

三、梯度检查点

尽管内存极简算法在降低内存开销方面取得了显著成果,但它涉及大量的重复计算,运行时间相对较长。为了在内存使用和运行时间之间取得平衡,下面引入梯度检查点(Gradient Checkpoint)。这一算法的核心思想是选择一些中间节点作为存储点,以便在再次触发向前传播时,以这些存储点作为起点开始传播,避免从头开始重复计算。这种方式在一定程度上减少重复计算,从而提高运行效率。需要注意的是,由于需要存储额外的中间结果,梯度检查点会稍微增加一些内存开销。

关于梯度检查点算法,PyTorch中已经提供了便捷的封装函数,即torch.utils.checkpoint。这个工具能够帮助我们更方便地应用梯度检查点算法,以平衡内存开锁和运行时间。更多细节请参考这个链接

  • 8
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值