摘要
自动微分(Automatic Differentiation,AD)是一种对计算机程序进行高效准确求导的技术,一直被广泛应用于计算流体力学、大气科学、工业设计仿真优化等领域。而近年来,机器学习技术的兴起也驱动着对自动微分技术的研究进入一个新的阶段。随着自动微分和其他微分技术研究的深入,其与编程语言、计算框架、编译器等领域的联系愈发紧密,从而衍生扩展出更通用的可微编程概念。本文章将对自动微分技术和可微编程的研究和发展进行概括综述。
本文章将分为三部分,主要框架如下:
- 常见计算机程序求导方法介绍
- 业内自动微分和可微编程方案介绍
- 自动微分和可微编程待解决的问题和展望
常见计算机程序求导方法介绍
对计算机程序求导的方法可以归纳为以下四种:
- 手工求导并编写对应的结果程序(Manual Differentiation)
- 通过有限差分近似方法完成求导,称为数值微分(Numerical Differentiation)
- 基于数学规则和程序表达式变换完成求导,称为符号微分(Symbolic Differentiation)
- 介于数值微分和符号微分之间的一种求导方法,也是本文介绍的重点,称为自动微分(Automatic Differentiation)
除去手工求导不在我们考虑的范畴内,下面我们主要介绍其余三种求导方法并分析它们的利弊。
数值微分
数值微分1(图 1 右下角所示)使用差分近似方法完成,其本质是根据导数的定义推导而来。
f ′ ( x ) = l i m h → 0 f ( x + h ) − f ( x ) h f'(x)=lim_{h \to 0}{\frac{f(x+h)-f(x)}{h}} f′(x)=limh→0hf(x+h)−f(x)
观察导数的定义容易想到,当 h h h 充分小时,可以用差商 f ( x + h ) − f ( x ) h \frac{f(x+h)-f(x)}{h} hf(x+h)−f(x) 近似导数结果。而近似的一部分误差(截断误差,Truncation Error2)可以由泰勒公式中的二阶及二阶后的所有余项给出:
f ( x ± h ) = f ( x ) ± h f ′ ( x ) + h 2 2 ! f ′ ′ ( x ) ± h 3 3 ! f ′ ′ ′ ( x ) + . . . + ( ± h ) n n ! f ( n ) ( x ) f(x \pm h) = f(x) \pm hf'(x) + \frac{h^2}{2!}f''(x) \pm \frac{h^3}{3!}f'''(x) + ... + {(\pm h)^n}{n!}f^{(n)}(x) f(x±h)=f(x)±hf′(x)+2!h2f′′(x)±3!h3f′′′(x)+...+(±h)nn!f(n)(x)
因此数值微分中常用的三种计算方式及其对应的截断误差可以归纳如下:
- 向前差商(Forward Difference)
∂ f ( x ) ∂ x ≈ f ( x + h ) − f ( x ) h \frac{\partial f(x)}{\partial x} \approx \frac{f(x+h)-f(x)}{h} ∂x∂f(x)≈hf(x+h)−f(x)
截断误差: O ( h ) O(h) O(h) - 向后差商(Reverse Difference)
∂ f ( x ) ∂ x ≈ f ( x ) − f ( x − h ) h \frac{\partial f(x)}{\partial x} \approx \frac{f(x)-f(x-h)}{h} ∂x∂f(x)≈hf(x)−f(x−h)
截断误差: O ( h ) O(h) O(h) - 中心差商(Center Difference)
∂ f ( x ) ∂ x ≈ f ( x + h ) − f ( x − h ) 2 h \frac{\partial f(x)}{\partial x} \approx \frac{f(x+h)-f(x-h)}{2h} ∂x∂f(x)≈2hf(x+h)−f(x−h)
截断误差: O ( h 2 ) O(h^2) O(h2)
可以看出来,数值微分中的截断误差与步长 h h h 有关, h h h 越小则截断误差越小,近似程序越高。
但实际情况数值微分的精确度并不会随着 h h h 的减小而无限减小,因为计算机系统中对于浮点数的运算由于其表达方式存在另外一种误差(舍入误差,Round-off Error),而舍入误差则会随着 h h h 变小而逐渐增大。因此在截断误差和舍入误差的共同作用下,数值微分的精度将会形成一个变化的函数并在某一个 h h h 值处达到最小值。
因此数值微分的优缺点可以简单总结如下:
- 优点:简单易实现
- 缺点:存在精度误差问题
符号微分
符号微分 3(图 1 右侧中间)是通过一系列如下的数学规则对计算机程序中的表达式进行递归变换来完成求导。
∂ ∂ x ( f ( x ) + g ( x ) ) = ∂ ∂ x f ( x ) + ∂ ∂ x g ( x ) \frac {\partial}{\partial x}(f(x)+g(x))=\frac {\partial}{\partial x}f(x) + \frac {\partial}{\partial x}g(x) ∂x∂(f(x)+g(x))=∂x∂f(x)+∂x∂g(x)
∂ ∂ x ( f ( x ) g ( x ) ) = ( ∂ ∂ x f ( x ) ) g ( x ) + f ( x ) ( ∂ ∂ x g ( x ) ) \frac {\partial}{\partial x}(f(x)g(x))=(\frac {\partial}{\partial x}f(x))g(x) + f(x)(\frac {\partial}{\partial x}g(x)) ∂x∂(f(x)g(x))=(∂x∂f(x))g(x)+f(x)(∂x∂g(x))
由于变换过程中并不涉及计算且是严格等价,因此其可以大大减小微分结果的误差(仅存在变换完成后计算过程中的舍入误差)。除此之外,符号微分的计算方式使其还能用于类似极值 ∂ ∂ x f ( x ) = 0 \frac {\partial}{\partial x}f(x)=0