This article walks through how to build an automatic differentiation framework from the VJP (vector-Jacobian product) perspective.
1 Basic VJP differential operators
VJP operators are the cornerstone of building autodiff from the VJP perspective. Because writing some of these operators from scratch is complicated and error-prone, we adopt the operator definitions from the autograd framework directly. A simple binary operator looks like this:
defvjp(
    np.subtract,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g),   # d(x - y)/dx: pass g through unchanged
    lambda ans, x, y: unbroadcast_f(y, lambda g: -g),  # d(x - y)/dy: negate g
)
The defvjp call registers the two differential operators corresponding to the binary function np.subtract, one per input. Each derivative computation takes at least four input variables, supplied in two stages.
In the first stage (graph construction) we pass in the result ans and the two inputs x and y; any other arguments that can affect differentiation, such as the axis argument of np.sum, must also be supplied at this stage. In the second stage (backpropagation) we pass in the gradient g flowing into the node.
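To make the two stages concrete, here is a minimal self-contained sketch (not the framework's actual code); unbroadcast_f below is a simplified stand-in for autograd's helper of the same name, which sums a gradient back down to the shape of the corresponding input:

import numpy as np

def unbroadcast_f(target, f):
    # Simplified stand-in: reduce the gradient back to target's shape.
    target_shape = np.shape(target)
    def wrapped(g):
        out = f(g)
        while np.ndim(out) > len(target_shape):    # drop leading dims added by broadcasting
            out = np.sum(out, axis=0)
        for axis, size in enumerate(target_shape):
            if size == 1:                          # collapse dims broadcast up from size 1
                out = np.sum(out, axis=axis, keepdims=True)
        return out
    return wrapped

x = np.array([[1.0, 2.0]])            # shape (1, 2)
y = np.array([[3.0], [4.0]])          # shape (2, 1)
ans = np.subtract(x, y)               # shape (2, 2) via broadcasting

vjp_maker = lambda ans, x, y: unbroadcast_f(x, lambda g: g)  # first operator from defvjp above
vjp_x = vjp_maker(ans, x, y)          # stage 1: graph construction binds ans, x, y
grad_x = vjp_x(np.ones_like(ans))     # stage 2: backpropagation feeds in g
print(grad_x)                         # [[2. 2.]] -- reduced back to x's shape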
2 Building the computation graph
Any computation can be organized as a directed acyclic graph. Take the formula $f(x_1, x_2) = \ln(x_1) + x_1 x_2 - \sin(x_2)$; assuming $x_1 = 2$ and $x_2 = 5$, the computation graph can be drawn as follows:
For the backward pass that follows, we need a data structure that keeps every node of the computation graph around; call it VJPDiffArray. It should have the following attributes:
id: the node's id
value: the node's numeric value, e.g. $\ln 2$
parents: the nodes that feed into the current node; e.g. the parents of node $v_4$ are $[v_2, v_3]$
vjp: the derivative function obtained during the forward pass by feeding ans, x, y, etc. into the differential operator; during the backward pass, feeding in g yields the current node's gradients with respect to its parents
All of these attributes are assigned during the forward pass by a function called register_diff.
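The post never shows the constructor, but a minimal skeleton consistent with the attribute list above could look like this (the id counter and constructor signature are assumptions for illustration):

import itertools
import numpy as np

class VJPDiffArray:
    # Sketch of the node structure; attribute names follow the list above.
    _id_counter = itertools.count()

    def __init__(self, value):
        self.id = next(VJPDiffArray._id_counter)  # node id
        self.value = value        # numeric value of the node, e.g. ln 2
        self._parents = None      # nodes feeding into this node
        self._vjp = None          # derivative function bound during the forward pass
        self._diff = None         # gradients accumulated by _backward
        self._jacobian = None     # Jacobian slices accumulated by _backward_jacobian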
def register_diff(self, func, args, kwargs):
    """
    Register the derivative function used in backward propagation
    for the current node.
    """
    try:
        # Look up the corresponding differential operator.
        if func is np.ufunc.__call__:
            vjpmaker = primitive_vjps[args[0]]
        else:
            vjpmaker = primitive_vjps[func]
    except KeyError:
        raise NotImplementedError(f"VJP of {func} not defined")
    vjp_args = []
    if self._parents is None:
        self._parents = []
    for arg in args:
        if isinstance(arg, VJPDiffArray):
            # Record graph edges: these arguments are this node's parents.
            self._parents.append(arg)
            vjp_args.append(arg)
        elif not isinstance(arg, np.ufunc):
            # Keep non-node extras (e.g. axis), but drop the ufunc itself.
            vjp_args.append(arg)
    parent_argnums = tuple(range(len(self._parents)))
    # Stage 1: feed in ans, x, y (and extras) to obtain the inner function.
    self._vjp = vjpmaker(parent_argnums, self, tuple(vjp_args), kwargs)
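The post does not show how register_diff gets triggered. One plausible wiring, consistent with the np.ufunc.__call__ branch above, is NumPy's __array_ufunc__ dispatch protocol; the following VJPDiffArray method is purely an illustrative assumption, not the framework's actual code:

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
    # Compute the raw result on the underlying values...
    raw = [a.value if isinstance(a, VJPDiffArray) else a for a in inputs]
    result = VJPDiffArray(getattr(ufunc, method)(*raw, **kwargs))
    # ...then record the edge; the ufunc rides along as args[0], matching
    # the primitive_vjps[args[0]] lookup inside register_diff.
    result.register_diff(np.ufunc.__call__, (ufunc, *inputs), kwargs)
    return result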
3 Backpropagation
Once the computation graph is built, differentiation itself is fairly simple (even though I initially imagined it to be much more complicated): we can express it as a backward propagation pass over the graph:
For each node, we feed in the upstream gradient with respect to the current node and obtain the current node's gradients with respect to its parents. This is handled by the _backward function:
def _backward(self, grad_variables, end_node, base):
    """
    Backpropagation: recursively walk the graph backwards from the
    end node toward its parents, accumulating gradient contributions.
    """
    if grad_variables is None:
        # For the starting node (v6 in the example), grad_variables
        # weights each output component; default to all ones.
        grad_variables = np.ones_like(self.value)
    if end_node is None:
        end_node = self
    if base is None or base.id == self.id:
        # Only accumulate at the target node (or everywhere if no base given).
        if self._diff is None:
            self._diff = {}
        if end_node in self._diff:
            self._diff[end_node] = self._diff[end_node] + grad_variables
        else:
            self._diff[end_node] = grad_variables
    if self._vjp:
        # Stage 2: feed g into the bound vjp to get the parents' gradients.
        diffs = list(self._vjp(grad_variables))
        for i, p in enumerate(self._parents):
            p._backward(diffs[i], end_node, base)
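As a quick sanity check on the running example $f(x_1, x_2) = \ln(x_1) + x_1 x_2 - \sin(x_2)$ at $x_1 = 2$, $x_2 = 5$: backpropagating an initial gradient of 1 should yield

$$\frac{\partial f}{\partial x_1} = \frac{1}{x_1} + x_2 = 0.5 + 5 = 5.5, \qquad \frac{\partial f}{\partial x_2} = x_1 - \cos(x_2) = 2 - \cos 5 \approx 1.7163.$$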
4 Computing the Jacobian
With the two functions above in place, we can quickly build on them to obtain the actual Jacobian matrix. Since a VJP is just the derivative of np.sum(y*v) with respect to x, the slice j[i][j] of the Jacobian matrix j (the derivative of y[i][j] with respect to x) can be viewed as the VJP of y with respect to x where v is a matrix that is 1 at position [i,j] and 0 everywhere else.
The implementation:
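In symbols: $\operatorname{vjp}(v) = \frac{\partial}{\partial x}\sum_{i,j} y_{ij} v_{ij}$, so choosing $v = e_{ij}$ (the one-hot matrix described above) gives $\operatorname{vjp}(e_{ij}) = \partial y_{ij}/\partial x$, a single slice of the Jacobian. Stacking these slices over all positions and reshaping produces the full matrix, which is exactly what the to API further below does.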
def _backward_jacobian(self, grad_variables, end_node, position, base):
    """
    Like _backward, but accumulates one Jacobian slice per output position.
    """
    if base is None or base.id == self.id:
        if self._jacobian is None:
            self._jacobian = {}
        if end_node not in self._jacobian:
            self._jacobian[end_node] = {}
        if position not in self._jacobian[end_node]:
            self._jacobian[end_node][position] = grad_variables
        else:
            self._jacobian[end_node][position] = (
                self._jacobian[end_node][position] + grad_variables
            )
    if self._vjp:
        diffs = list(self._vjp(grad_variables))
        for i, p in enumerate(self._parents):
            p._backward_jacobian(diffs[i], end_node, position, base)
Finally, we provide a public API, to:
def to(self, x, grad_variables=None, jacobian=False):
    """
    Compute the VJP or the Jacobian matrix of self with respect to x.
    """
    if jacobian:
        if x._jacobian is None or self not in x._jacobian:
            # One backward pass per output position, seeded with a one-hot v.
            for position in itertools.product(*[range(i) for i in np.shape(self)]):
                grad_variables = np.zeros_like(self.value)
                grad_variables[position] = 1
                self._backward_jacobian(grad_variables, self, position, x)
            # Stack the slices (dict preserves insertion order) and reshape to
            # the output shape followed by the input shape.
            x._jacobian[self] = np.reshape(
                np.stack(list(x._jacobian[self].values())),
                np.shape(self) + np.shape(x),
            )
        return x._jacobian[self]
    else:
        if x._diff is None or self not in x._diff:
            self._backward(grad_variables, self, x)
        return x._diff[self]
Note that because the backward pass also registers gradients for the computations it performs (it too calls register_diff), higher-order derivatives in VJP mode can be obtained simply by calling to repeatedly.
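For example, assuming ufunc calls on VJPDiffArray are intercepted and recorded as sketched earlier, usage could look like this hypothetical snippet:

x1 = VJPDiffArray(np.array(2.0))
x2 = VJPDiffArray(np.array(5.0))
f = np.log(x1) + x1 * x2 - np.sin(x2)

df_dx1 = f.to(x1)              # first derivative: 1/x1 + x2 = 5.5
d2f_dx1 = df_dx1.to(x1)        # calling to again gives the second derivative: -1/x1**2 = -0.25
J = f.to(x1, jacobian=True)    # full Jacobian of f with respect to x1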