文章目录
神经翻译笔记3扩展b. 自动微分
本文无说明的部分(包括配图)均是翻译/演绎自:
Baydin, A. G., Pearlmutter, B. A., Radul, A. A., & Siskind, J. M. (2017). Automatic differentiation in machine learning: a survey. Journal of machine learning research, 18(153), 1-43.
不过没有包含若干偏理论的内容
其它引用会单独注明
引言
如前所示,在训练神经网络时,需要计算损失函数对网络参数的梯度,其中会涉及到很多次导数的计算。一般来讲,编程计算导数有四种做法
-
手动微分(manual differentiation),手动推出导数是什么样,然后硬编码。这种做法既耗时也容易出错,还没有灵活性
-
数值微分(numerical differentiation),利用数值代数方法逼近函数的导数值。这种方法存在舍入误差和截断误差,而且扩展性差,在深度学习需要计算百万量级参数的梯度时不适用
-
符号微分(symbolic differentiation)。通常是计算机代数系统采用,例如Mathematica, Maxima和Maple等等。这种方法试图给出给定表达式导数的代数形式,但是会导致表达式爆炸的现象。而且其底层依赖一个封闭的表达式库,给方法求解问题的范畴施加了局限
-
自动微分(automatic differentiation),或者也被称为算法微分(algorithmic differentiation),是本文的主题
自动微分不是什么
自动微分不是数值微分
这里先介绍一下数值微分的计算思想。考虑到导数的定义,如果一个函数 y = f ( x ) y=f(x) y=f(x)在点 x 0 x_0 x0处可导,那么其在该点处的导数为
f ′ ( x 0 ) = lim Δ x → 0 Δ y Δ x = lim Δ x → 0 f ( x 0 + Δ x ) − f ( x 0 ) Δ x f'(x_0) = \lim_{\Delta x\rightarrow 0}\frac{\Delta y}{\Delta x} = \lim_{\Delta x\rightarrow 0} \frac{f(x_0 + \Delta x) - f(x_0)}{\Delta x} f′(x0)=Δx→0limΔxΔy=Δx→0limΔxf(x0+Δx)−f(x0)
因此可以使用导数的定义,使用一个特别小的 Δ x \Delta x Δx(例如 1 0 − 6 10^{-6} 10−6)来计算导数。但是这种做法会产生偏差。假设以 h h h代替 Δ x \Delta x Δx,那么近似计算 f f f在 x x x点处的导数为
f ′ ( x ) ≈ f ( x + h ) − f ( x ) h f'(x) \approx \frac{f(x+h) - f(x)}{h} f′(x)≈hf(x+h)−f(x)
而 f ( x + h ) f(x+h) f(x+h)关于 x x x的泰勒展开为
f ( x + h ) = f ( x ) + h f ′ ( x ) + h 2 2 f ′ ′ ( ξ ) , ξ ∈ ( x , x + h ) f(x+h) = f(x) + hf'(x) + \frac{h^2}{2}f''(\xi),\ \ \ \xi \in (x, x+h) f(x+h)=f(x)+hf′(x)+2h2f′′(ξ), ξ∈(x,x+h)
因此会存在一个 − h 2 f ′ ′ ( ξ ) -\frac{h}{2}f''(\xi) −2hf′′(ξ)的误差,该误差称为截断误差(truncation error)。这种方法称为前向差分法(forward differencing),注意其截断误差是 O ( h ) O(h) O(h)的
对于前向微分法,有一种改进的方法可以提高估计的准确率,称为中心差分法(centered differencing)
f ′ ( x ) ≈ f ( x + h ) − f ( x − h ) 2 h f'(x) \approx \frac{f(x+h) - f(x-h)}{2h} f′(x)≈2hf(x+h)−f(x−h)
(验证内容来自于马里兰大学学院市分校(UMD)数值分析课AMSC466的讲义)下面验证该式有更好的准确度。该式右侧的泰勒展开为
f ( x + h ) = f ( x ) + h f ′ ( x ) + h 2 2 f ′ ′ ( x ) + h 3 6 f ′ ′ ′ ( ξ 1 ) f ( x − h ) = f ( x ) − h f ′ ( x ) + h 2 2 f ′ ′ ( x ) − h 3 6 f ′ ′ ′ ( ξ 1 ) \begin{aligned} f(x+h) &= f(x) + hf'(x) + \frac{h^2}{2}f''(x) + \frac{h^3}{6}f'''(\xi_1) \\ f(x-h) &= f(x) - hf'(x) + \frac{h^2}{2}f''(x) - \frac{h^3}{6}f'''(\xi_1) \end{aligned} f(x+h)f(x−h)=f(x)+hf′(x)+2h2f′′(x)+6h3f′′′(ξ1)=f(x)−hf′(x)+2h2f′′(x)−6h3f′′′(ξ1)
其中 ξ 1 ∈ ( x , x + h ) , ξ 2 ∈ ( x − h , x ) \xi_1 \in (x, x+h), \xi_2 \in (x-h, x) ξ1∈(x,x+h),ξ2∈(x−h,x)。因此
f ′ ( x ) = f ( x + h ) − f ( x − h ) 2 h − h 2 12 [ f ′ ′ ′ ( ξ 1 ) + f ′ ′ ′ ( x 2 ) ] f'(x) = \frac{f(x+h) - f(x-h)}{2h} - \frac{h^2}{12}[f'''(\xi_1) + f'''(x_2)] f′(x)=2hf(x+h)−f(x−h)−12h2[f′′′(ξ1)+f′′′(x2)]
即中心差分法的截断误差是 − h 2 12 [ f ′ ′ ′ ( ξ 1 ) + f ′ ′ ′ ( ξ 2 ) ] -\frac{h^2}{12}[f'''(\xi_1) + f'''(\xi_2)] −12h2[f′′′(ξ1)+f′′′(ξ2)]。假设三阶导数在区间 [ x − h , x + h ] [x-h, x+h] [x−h,x+h]连续,那么由介值定理,存在点 ξ ∈ ( x − h , x + h ) \xi \in (x-h, x+h) ξ∈(x−h,x+h)使得
f ′ ′ ′ ( ξ ) = 1 2 [ f ′ ′ ′ ( ξ 1 ) + f ′ ′ ′ ( ξ 2 ) ] f'''(\xi) = \frac{1}{2}[f'''(\xi_1) + f'''(\xi_2)] f′′′(ξ)=21[f′′′(ξ1)+f′′′(ξ2)]
因此
f ′ ( x ) = f ( x + h ) − f ( x − h ) 2 h − h 2 6 f ′ ′ ′ ( ξ ) f'(x) = \frac{f(x+h) - f(x-h)}{2h}-\frac{h^2}{6}f'''(\xi) f′(x)=2hf(x+h)−f(x−h)−6h2f′′′(ξ)
即中心差分法的截断误差是 O ( h 2 ) O(h^2) O(h2)的。当 h h h很小时,该误差小于前向差分法的误差 ■ \blacksquare ■
**数值微分的缺点是存在截断误差和舍入误差,同时计算太慢。**比较不幸的是,随着 h h h大小的变化,截断误差和舍入误差的变化趋势相反:当 h h h趋近于0时,截断误差也趋近于0,但是舍入误差会慢慢增大,反之相反。下图给出了函数 f ( x ) = 64 x ( 1 − x ) ( 1 − 2 x ) 2 ( 1 − 8 x + 8 x 2 ) 2 f(x) = 64x(1-x)(1-2x)^2(1-8x+8x^2)^2 f(x)=64x(1−x)(1−2x)2(1−8x+8x2)2使用数值微分计算在点 x 0 = 0.2 x_0 = 0.2 x0=0.2的导数时误差随 h h h变化的图像。
此外还需注意一点:当参数是标量时,前向差分法和中心差分法的计算代价相同。不过当参数是向量时,使用中心差分法计算函数 f : R n → R m f: \mathbb{R}^n \rightarrow \mathbb{R}^m f:Rn→Rm的雅可比矩阵需要额外 m n mn mn个计算量。尤其在深度学习领域,对于 n n n维向量,这种 O ( n ) O(n) O(n)的计算量是算法的主要瓶颈,而误差已经不重要了
自动微分不是符号微分
符号微分将输入式子表达为一个表达式树,然后对每个节点使用一些预先设置好的规则做转换。符号微分可以帮助人们更深入地了解问题域的结构,有时候还能给出极值条件的解析解,不过它们会产生指数量级的表达式,因此计算起来效率很低。考虑函数 h ( x ) = f ( x ) g ( x ) h(x) = f(x)g(x) h(x)=f(x)g(x)和微分的乘法法则
d d x ( f ( x ) g ( x ) ) ⇝ ( d d x f ( x ) ) g ( x ) + f ( x ) ( d d x g ( x ) ) \frac{d}{dx}(f(x)g(x)) \rightsquigarrow \left(\frac{d}{dx}f(x)\right)g(x) + f(x)\left(\frac{d}{dx}g(x)\right) dxd(f(x)g(x))⇝(dxdf(x))g(x)+f(x)(dxdg(x))
由于 h h h是两个函数的乘积,因此 h ( x ) h(x) h(x)和 d d x h ( x ) \frac{d}{dx}h(x) dxdh(x)有相同的成分,分别是 f ( x ) f(x) f(x)和