tensorflow 显存 训练_【他山之石】训练时显存优化技术——OP合并与gradient checkpoint...

作者:bindog

地址:http://bindog.github.io/

01

背景

前几天看到知乎上的文章 FLOPs与模型推理速度 [1],文中提到一个比较耗时又占显存的pointwise操作x * sigmoid(x),这实际上是 swish activation [2];暂且不提它背后的争议,本文主要想从这个结构入手来优化它的显存占用以及耗时,并讨论更广泛的训练时显存优化技术。

02

反向传播是如何工作的?

要分析清楚swish activation为什么会比较占显存,我们首先需要搞清楚反向传播是如何工作的,或者更进一步说,现有的自动求导框架是如何求出梯度的。 先明确一点,所谓自动求导框架实际上是“半自动”的:它并非直接求出一个复杂函数导数的解析形式,而是通过构建计算图和预先写好的基础函数的求导规则,结合链式求导法则实现的自动求导。 以swish acivation为例进行说明,其表达式为f(x) = x * sigmoid(x),通过简单的数学推导得到其梯度的解析式为f'(x) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x));先把这个结果放一边,看看自动求导框架是如何一步步求出这个结果的,画出计算图如下: 5e8bcb985253a81d76a9ca91f9151f8d.png 除了计算图以外,我们还需要定义几个基本函数的求导规则,在这个例子里涉及两个函数,一个是乘法,另一个是sigmoid函数(实际上sigmoid也是由几个基本函数构成的,这里我们将其视为一个整体)
f(x, y) = x * y# gradient for x: y# gradient for y: xg(x) = sigmoid(x)  # 1 / (1 + exp(-x))# gradient for x: sigmoid(x) * (1 - sigmoid(x))

03

显存被谁吃掉了

先说一个结论,在绝大多数神经网络的训练过程中,显存占用的大头是中间结果,也就是所谓的“特征图”。那我们为什么要保留中间结果呢?当然是为了方便求导啊!还是以swish acivation为例,把它放入神经网络来看,x就是前一层输出的中间结果(特征图)
  • 在适用乘法的求导规则时,要求我们要事先保留下中间结果x和sigmoid(x),有人可能会说只保留一个x不就可以了吗?sigmoid(x)可以通过计算得出,注意框架定义的乘法及其求导规则是通用规则,乘法的左右两边完全可能是不相关的两个值,所以必须同时保留下来。
  • 在对sigmoid函数适用求导规则时,需要存下中间结果x。
在不考虑框架自身优化的情况下,显存占用就包括了两个x和一个sigmoid(x),注意x可不是一个单独的数值,而是类似32x32x128这样大小的特征图,考虑到swish acivation在网络中数量庞大,每出现一次就意味着巨大的显存浪费。

04

手动合并OP

那么有没有办法优化呢?当然是可以的,既然我们能用数学公式提前算出swish acivation的梯度,那么直接将其视为一个整体不就好了?无非就是定义一个新的函数和新的求导规则

swish(x) = x * sigmoid(x)# gradient for x: sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))

这样一来,计算图变成了下面这个样子:

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值