摘要
自动微分(Automatic Differentiation,AD)是一种对计算机程序进行高效准确求导的技术,一直被广泛应用于计算流体力学、大气科学、工业设计仿真优化等领域。而近年来,机器学习技术的兴起也驱动着对自动微分技术的研究进入一个新的阶段。随着自动微分和其他微分技术研究的深入,其与编程语言、计算框架、编译器等领域的联系愈发紧密,从而衍生扩展出更通用的 可微编程 概念。本系列文章将对自动微分技术和可微编程的研究和发展进行概括综述。
本文将分为三部分,主要框架如下:
- 常见计算机程序求导方法介绍
- 业内自动微分和可微编程方案介绍
- 自动微分和可微编程待解决的问题和展望
当然如果读者想了解更多有关可微编程相关的技术内容,可以持续关注 SIG-可微编程发表的相关文章。我们也欢迎读者朋友加入我们的编程语言技术社区 SIG-可微编程小组,和我们一起深入探讨可微编程相关技术。
自动微分实现
在上一篇的文章中,我们介绍了自动微分的基本数学原理。可以总结自动微分的关键步骤为:
- 分解程序为一系列已知微分规则的基础表达式的组合
- 根据已知微分规则给出各基础表达式的微分结果
- 根据基础表达式间的数据依赖关系使用链式法则将微分结果组合完成程序的微分结果
虽然自动微分的数学原理已经明确,但具体的实现方法则可以有很大的差异。2018 年,Siskind 等学者在其综述论文 [1] 中对自动微分实现方案划分为三类:
- 基本表达式法:封装一系列基本的表达式(如:加减乘除等)及其对应的微分结果表达式作为库函数,用户通过调用库函数构建需要被微分的程序。而封装后的库函数在运行时会记录所有的基本表达式和相应的组合关系,最后使用链式法则对上述基本表达式的微分结果进行组合完成自动微分。
- 操作符重载法:利用现代语言的多态特性,使用操作符重载对语言中基本运算表达式的微分规则进行封装。类似地,重载后的操作符在运行时会记录所有的操作符和相应的组合关系,最后使用链式法则对上述基本表达式的微分结果进行组合完成自动微分。
- 代码变换法:通过对语言预处理器、编译器或解释器的扩展,将其中程序表达(如:源码、AST 或 IR)的基本表达式微分规则进行预定义,再对程序表达进行分析得到基本表达式的组合关系,最后使用链式法则对上述基本表达式的微分结果进行组合生成对应微分结果的新程序表达,完成自动微分。
以a = (x + y) / z
为例,下面介绍三种自动微分实现方法并分析它们的利弊。
基本表达式法
以 Wengert 等人在 1964 年提出的自动微分实现方法 [2] 为例,用户首先需要手动将上述函数分解为库函数中基本表达式组合:
t1 = x + y
a = t1 / z
使用给定的库函数完成上述函数的程序设计:
// 参数为变量 x,y,t1 和对应的导数变量 dx,dy,dt1
call ADAdd(x, dx, y, dy, t1, dt1)
// 参数为变量 t1,z,a 和对应的导数变量 dt1,dz,da
call ADDiv(t1, dt1, z, dz, a, da)
而库函数中则定义了对应表达式的数学微分规则和链式法则:
def ADAdd(x, dx, y, dy, z, dz):
z = x + y
dz = dy + dx
def ADDiv(x, dx, y, dy, z, dz):
z = x / y
dz = dx / y + (x / (y * y)) * dy
基本表达式法的优缺点可以总结如下:
- 优点:实现简单,基本可在任意语言中快速地实现为库
- 缺点:用户必须使用库函数进行编程,而无法使用语言原生的运算表达式
操作符重载法
以 2013 年 Shtof 等人 [3] 在 csharp 语言上开发的自动微分库 AutoDiff [4] 为例,该自动微分库预定义了特定的数据结构,并对该数据结构重载了相应的基本运算操作符。
namespace AutoDiff
{
public abstract class Term
{
// 重载操作符 `+`,`*` 和 `/`,调用这些操作符时,会通过其中的
// TermBuilder 将操作的类型、输入输出信息等记录至 tape 中
public static Term operator+(Term left, Term right)
{
return TermBuilder.Sum(left, right);
}
public static Term operator*(Term left, Term right)
{
return TermBuilder.Product(left, right);
}
public static Term operator/(Term numerator, Term denominator)
{
return TermBuilder.Product(numerator, TermBuilder.Power(denominator, -1));
}
}
}
当用户使用该数据类型中重载的表达式完成函数定义后,程序在实际执行时会将相应表达式的操作类型和输入输出信息记录至一个 tape 数据结构中。
using AutoDiff;
class Program
{
public static void Main(string[] args)
{
// 变量定义,注:Variable 是 Term 的子类型
var x = new Variable();