作者:bindog
地址:http://bindog.github.io/
01
背景
前几天看到知乎上的文章 FLOPs与模型推理速度 [1],文中提到一个比较耗时又占显存的pointwise操作x * sigmoid(x),这实际上是 swish activation [2];暂且不提它背后的争议,本文主要想从这个结构入手来优化它的显存占用以及耗时,并讨论更广泛的训练时显存优化技术。02
反向传播是如何工作的?
要分析清楚swish activation为什么会比较占显存,我们首先需要搞清楚反向传播是如何工作的,或者更进一步说,现有的自动求导框架是如何求出梯度的。 先明确一点,所谓自动求导框架实际上是“半自动”的:它并非直接求出一个复杂函数导数的解析形式,而是通过构建计算图和预先写好的基础函数的求导规则,结合链式求导法则实现的自动求导。 以swish acivation为例进行说明,其表达式为f(x) = x * sigmoid(x),通过简单的数学推导得到其梯度的解析式为f'(x) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x));先把这个结果放一边,看看自动求导框架是如何一步步求出这个结果的,画出计算图如下: 除了计算图以外,我们还需要定义几个基本函数的求导规则,在这个例子里涉及两个函数,一个是乘法,另一个是sigmoid函数(实际上sigmoid也是由几个基本函数构成的,这里我们将其视为一个整体)f(x, y) = x * y# gradient for x: y# gradient for y: xg(x) = sigmoid(x) # 1 / (1 + exp(-x))# gradient for x: sigmoid(x) * (1 - sigmoid(x))
03
显存被谁吃掉了
先说一个结论,在绝大多数神经网络的训练过程中,显存占用的大头是中间结果,也就是所谓的“特征图”。那我们为什么要保留中间结果呢?当然是为了方便求导啊!还是以swish acivation为例,把它放入神经网络来看,x就是前一层输出的中间结果(特征图)- 在适用乘法的求导规则时,要求我们要事先保留下中间结果x和sigmoid(x),有人可能会说只保留一个x不就可以了吗?sigmoid(x)可以通过计算得出,注意框架定义的乘法及其求导规则是通用规则,乘法的左右两边完全可能是不相关的两个值,所以必须同时保留下来。
- 在对sigmoid函数适用求导规则时,需要存下中间结果x。
04
手动合并OP
那么有没有办法优化呢?当然是可以的,既然我们能用数学公式提前算出swish acivation的梯度,那么直接将其视为一个整体不就好了?无非就是定义一个新的函数和新的求导规则
swish(x) = x * sigmoid(x)
# gradient for x: sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
这样一来,计算图变成了下面这个样子: