神经翻译笔记3扩展b. 自动微分

本文介绍了自动微分的概念,区分了它与数值微分和符号微分的区别。重点讲解了自动微分的前向模式和后向模式,分析了它们在机器学习中的应用,特别是对神经网络训练的影响。前向模式适用于输入维度较小的情况,而后向模式则在输出维度较小但输入维度大的场景中更为高效。自动微分是现代深度学习框架如TensorFlow的基础,用于计算梯度,推动模型优化。
摘要由CSDN通过智能技术生成

神经翻译笔记3扩展b. 自动微分

本文无说明的部分(包括配图)均是翻译/演绎自:

Baydin, A. G., Pearlmutter, B. A., Radul, A. A., & Siskind, J. M. (2017). Automatic differentiation in machine learning: a survey. Journal of machine learning research, 18(153), 1-43.

不过没有包含若干偏理论的内容

其它引用会单独注明


引言

如前所示,在训练神经网络时,需要计算损失函数对网络参数的梯度,其中会涉及到很多次导数的计算。一般来讲,编程计算导数有四种做法

  • 手动微分(manual differentiation),手动推出导数是什么样,然后硬编码。这种做法既耗时也容易出错,还没有灵活性

  • 数值微分(numerical differentiation),利用数值代数方法逼近函数的导数值。这种方法存在舍入误差和截断误差,而且扩展性差,在深度学习需要计算百万量级参数的梯度时不适用

  • 符号微分(symbolic differentiation)。通常是计算机代数系统采用,例如Mathematica, Maxima和Maple等等。这种方法试图给出给定表达式导数的代数形式,但是会导致表达式爆炸的现象。而且其底层依赖一个封闭的表达式库,给方法求解问题的范畴施加了局限

  • 自动微分(automatic differentiation),或者也被称为算法微分(algorithmic differentiation),是本文的主题

自动微分不是什么

自动微分不是数值微分

这里先介绍一下数值微分的计算思想。考虑到导数的定义,如果一个函数 y = f ( x ) y=f(x) y=f(x)在点 x 0 x_0 x0处可导,那么其在该点处的导数为
f ′ ( x 0 ) = lim ⁡ Δ x → 0 Δ y Δ x = lim ⁡ Δ x → 0 f ( x 0 + Δ x ) − f ( x 0 ) Δ x f'(x_0) = \lim_{\Delta x\rightarrow 0}\frac{\Delta y}{\Delta x} = \lim_{\Delta x\rightarrow 0} \frac{f(x_0 + \Delta x) - f(x_0)}{\Delta x} f(x0)=Δx0limΔxΔy=Δx0limΔxf(x0+Δx)f(x0)
因此可以使用导数的定义,使用一个特别小的 Δ x \Delta x Δx(例如 1 0 − 6 10^{-6} 106)来计算导数。但是这种做法会产生偏差。假设以 h h h代替 Δ x \Delta x Δx,那么近似计算 f f f x x x点处的导数为
f ′ ( x ) ≈ f ( x + h ) − f ( x ) h f'(x) \approx \frac{f(x+h) - f(x)}{h} f(x)hf(x+h)f(x)
f ( x + h ) f(x+h) f(x+h)关于 x x x的泰勒展开为
f ( x + h ) = f ( x ) + h f ′ ( x ) + h 2 2 f ′ ′ ( ξ ) ,     ξ ∈ ( x , x + h ) f(x+h) = f(x) + hf'(x) + \frac{h^2}{2}f''(\xi),\ \ \ \xi \in (x, x+h) f(x+h)=f(x)+hf(x)+2h2f(ξ),   ξ(x,x+h)
因此会存在一个 − h 2 f ′ ′ ( ξ ) -\frac{h}{2}f''(\xi) 2hf(ξ)的误差,该误差称为截断误差(truncation error)。这种方法称为前向差分法(forward differencing),注意其截断误差是 O ( h ) O(h) O(h)
对于前向微分法,有一种改进的方法可以提高估计的准确率,称为中心差分法(centered differencing)
f ′ ( x ) ≈ f ( x + h ) − f ( x − h ) 2 h f'(x) \approx \frac{f(x+h) - f(x-h)}{2h} f(x)2hf(x+h)f(xh)
(验证内容来自于马里兰大学学院市分校(UMD)数值分析课AMSC466的讲义)下面验证该式有更好的准确度。该式右侧的泰勒展开为
f ( x + h ) = f ( x ) + h f ′ ( x ) + h 2 2 f ′ ′ ( x ) + h 3 6 f ′ ′ ′ ( ξ 1 ) f ( x − h ) = f ( x ) − h f ′ ( x ) + h 2 2 f ′ ′ ( x ) − h 3 6 f ′ ′ ′ ( ξ 1 ) \begin{aligned} f(x+h) &= f(x) + hf'(x) + \frac{h^2}{2}f''(x) + \frac{h^3}{6}f'''(\xi_1) \\ f(x-h) &= f(x) - hf'(x) + \frac{h^2}{2}f''(x) - \frac{h^3}{6}f'''(\xi_1) \end{aligned} f(x+h)f(xh)=f(x)+hf(x)+2h2f(x)+6h3f(ξ1)=f(x)hf(x)+2h2f(x)6h3f(ξ1)
其中 ξ 1 ∈ ( x , x + h ) , ξ 2 ∈ ( x − h , x ) \xi_1 \in (x, x+h), \xi_2 \in (x-h, x) ξ1(x,x+h),ξ2(xh,x)。因此
f ′ ( x ) = f ( x + h ) − f ( x − h ) 2 h − h 2 12 [ f ′ ′ ′ ( ξ 1 ) + f ′ ′ ′ ( x 2 ) ] f'(x) = \frac{f(x+h) - f(x-h)}{2h} - \frac{h^2}{12}[f'''(\xi_1) + f'''(x_2)] f(x)=2hf(x+h)f(xh)12h2[f(ξ1)+f(x2)]
即中心差分法的截断误差是 − h 2 12 [ f ′ ′ ′ ( ξ 1 ) + f ′ ′ ′ ( ξ 2 ) ] -\frac{h^2}{12}[f'''(\xi_1) + f'''(\xi_2)] 12h2[f(ξ1)+f(ξ2)]。假设三阶导数在区间 [ x − h , x + h ] [x-h, x+h] [xh,x+h]连续,那么由介值定理,存在点 ξ ∈ ( x − h , x + h ) \xi \in (x-h, x+h) ξ(xh,x+h)使得
f ′ ′ ′ ( ξ ) = 1 2 [ f ′ ′ ′ ( ξ 1 ) + f ′ ′ ′ ( ξ 2 ) ] f'''(\xi) = \frac{1}{2}[f'''(\xi_1) + f'''(\xi_2)] f(ξ)=21[f(ξ1)+f(ξ2)]
因此
f ′ ( x ) = f ( x + h ) − f ( x − h ) 2 h − h 2 6 f ′ ′ ′ ( ξ ) f'(x) = \frac{f(x+h) - f(x-h)}{2h}-\frac{h^2}{6}f'''(\xi) f(x)=2hf(x+h)f(xh)6h2f(ξ)
即中心差分法的截断误差是 O ( h 2 ) O(h^2) O(h2)的。当 h h h很小时,该误差小于前向差分法的误差 ■ \blacksquare

**数值微分的缺点是存在截断误差和舍入误差,同时计算太慢。**比较不幸的是,随着 h h h大小的变化,截断误差和舍入误差的变化趋势相反:当 h h h趋近于0时,截断误差也趋近于0,但是舍入误差会慢慢增大,反之相反。下图给出了函数 f ( x ) = 64 x ( 1 − x ) ( 1 − 2 x ) 2 ( 1 − 8 x + 8 x 2 ) 2 f(x) = 64x(1-x)(1-2x)^2(1-8x+8x^2)^2 f(x)=64x(1x)(12x)2(18x+8x2)2使用数值微分计算在点 x 0 = 0.2 x_0 = 0.2 x0=0.2的导数时误差随 h h h变化的图像。

数值微分误差与h的关系

此外还需注意一点:当参数是标量时,前向差分法和中心差分法的计算代价相同。不过当参数是向量时,使用中心差分法计算函数 f : R n → R m f: \mathbb{R}^n \rightarrow \mathbb{R}^m f:RnRm的雅可比矩阵需要额外 m n mn mn个计算量。尤其在深度学习领域,对于 n n n维向量,这种 O ( n ) O(n) O(n)的计算量是算法的主要瓶颈,而误差已经不重要了

自动微分不是符号微分

符号微分将输入式子表达为一个表达式树,然后对每个节点使用一些预先设置好的规则做转换。符号微分可以帮助人们更深入地了解问题域的结构,有时候还能给出极值条件的解析解,不过它们会产生指数量级的表达式,因此计算起来效率很低。考虑函数 h ( x ) = f ( x ) g ( x ) h(x) = f(x)g(x) h(x)=f(x)g(x)和微分的乘法法则
d d x ( f ( x ) g ( x ) ) ⇝ ( d d x f ( x ) ) g ( x ) + f ( x ) ( d d x g ( x ) ) \frac{d}{dx}(f(x)g(x)) \rightsquigarrow \left(\frac{d}{dx}f(x)\right)g(x) + f(x)\left(\frac{d}{dx}g(x)\right) dxd(f(x)g(x))(dxdf(x))g(x)+f(x)(dxdg(x))
由于 h h h是两个函数的乘积,因此 h ( x ) h(x) h(x) d d x h ( x ) \frac{d}{dx}h(x) dxdh(x)有相同的成分,分别是 f ( x ) f(x) f(x)

  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值