The previous article covered how to implement VJP (vector-Jacobian product) differentiation; this article describes how to implement JVP (Jacobian-vector product) differentiation in automatic differentiation.
1 Registering Gradients
VJP differentiation propagates derivatives from outputs back to inputs and is relatively simple to implement. JVP differentiation propagates derivatives from inputs forward to outputs, and one difficulty it faces is supporting higher-order derivatives in a clean way: computing a derivative registers new gradients, and JVP's front-to-back computation can easily trap the program in an infinite loop. To solve this, we make every JVPDiffArray self-contained, i.e., it carries all the information needed for any differentiation involving itself:
def register_diff(self, func, args, kwargs):
    """
    Register the information required for the forward pass.
    """
    # Look up the JVP rule of the primitive. A ufunc call carries the
    # actual primitive as its first argument.
    try:
        if func is np.ufunc.__call__:
            jvpmaker = primitive_jvps[args[0]]
        else:
            jvpmaker = primitive_jvps[func]
    except KeyError:
        raise NotImplementedError(f"JVP of {func} not defined")
    # Collect the differentiable parents and the arguments passed to the rule.
    jvp_args, parents = [], []
    for arg in args:
        if isinstance(arg, JVPDiffArray):
            jvp_args.append(arg)
            parents.append(arg)
        elif not isinstance(arg, np.ufunc):
            jvp_args.append(arg)
    parent_argnums = tuple(range(len(parents)))
    # The rule yields one jvp closure per differentiable parent.
    jvps = list(jvpmaker(parent_argnums, self, tuple(jvp_args), kwargs))
    if self._jvp is None:
        self._jvp = {}
    for p, jvp in zip(parents, jvps):
        # Record the direct chain from each parent to self.
        if p not in self._jvp:
            self._jvp[p] = [[jvp]]
        else:
            self._jvp[p] += [[jvp]]
        # Inherit the parent's chains so that self stays self-contained:
        # every chain from a base variable to the parent is extended by
        # the new jvp so that it reaches self.
        if p._jvp:
            for base in p._jvp:
                if base not in self._jvp:
                    self._jvp[base] = []
                self._jvp[base] += [flist + [jvp] for flist in p._jvp[base]]
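The `primitive_jvps` table referenced above maps each primitive to a rule factory that, given the parent argument numbers, the result, and the arguments, yields one jvp closure per differentiable parent. As a minimal sketch (the exact rules in the library are not shown in this article, but they must follow the `jvpmaker(parent_argnums, result, args, kwargs)` convention used in `register_diff`), a rule for multiplication could look like this:

# A hypothetical JVP rule for np.multiply, assuming two array arguments.
def mul_jvpmaker(parent_argnums, result, args, kwargs):
    x, y = args
    rules = {
        0: lambda g: g * y,  # tangent rule of x * y w.r.t. x
        1: lambda g: x * g,  # tangent rule of x * y w.r.t. y
    }
    # Yield one jvp closure per differentiable parent, keyed by the
    # parent's position among the arguments.
    return tuple(rules[argnum] for argnum in parent_argnums)

primitive_jvps[np.multiply] = mul_jvpmaker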
2 Forward Propagation
When computing gradients, since a JVPDiffArray already contains all the information it needs, we only have to evaluate the chains of jvp functions that have already been stored:
def _forward(self, x, grad_variables):
    # No chain registered means self is x itself, so the tangent
    # passes through unchanged.
    if self._jvp is None:
        return grad_variables
    result = []
    # Push the tangent through every chain from x to self, then sum
    # the contributions of all paths.
    for flist in self._jvp[x]:
        cur_result = grad_variables
        for f in flist:
            cur_result = f(cur_result)
        result.append(cur_result)
    return reduce(lambda a, b: a + b, result)  # reduce from functools
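To see what these chains compute, here is a standalone sketch (plain Python, independent of the JVPDiffArray machinery above) of a tangent flowing through the composition y = sin(x), z = y * y. Each primitive contributes one jvp closure, and the chain for z with respect to x is their composition:

import math
from functools import reduce

x = 2.0
y = math.sin(x)

# Hypothetical jvp closures produced for y = sin(x) and z = y * y.
jvp_sin = lambda t: math.cos(x) * t  # tangent rule of sin at x
jvp_mul = lambda t: 2 * y * t        # tangent rule of y * y at y

# The chain registered for z with respect to x would be [jvp_sin, jvp_mul];
# pushing the seed tangent 1.0 through it yields dz/dx.
chain = [jvp_sin, jvp_mul]
dz_dx = reduce(lambda t, f: f(t), chain, 1.0)

assert abs(dz_dx - 2 * math.sin(x) * math.cos(x)) < 1e-12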
Here we no longer need a separate _forward_jacobian: we simply set each position of the tangent vector v to 1 in turn and run a forward pass:
def to(self, x, grad_variables=None, jacobian=False):
    """
    Calculate the JVP or Jacobian matrix of self with respect to x.
    """
    if self._jvp and x not in self._jvp:
        raise ValueError("Please check if the base is correct.")
    if jacobian:
        if self._jacobian is None:
            self._jacobian = {}
        if x not in self._jacobian:
            self._jacobian[x] = {}
            # Run one forward pass per element of x, seeding the tangent
            # with a one-hot vector at that position.
            for position in itertools.product(*[range(i) for i in np.shape(x)]):
                grad_variables = np.zeros_like(x)
                grad_variables.value[position] = 1
                self._jacobian[x][position] = self._forward(x, grad_variables)
            # Stack the per-position results into shape(x) + shape(self),
            # then move the axes of self to the front.
            old_axes = tuple(range(np.ndim(self) + np.ndim(x)))
            new_axes = old_axes[np.ndim(x):] + old_axes[:np.ndim(x)]
            self._jacobian[x] = np.transpose(
                np.reshape(
                    np.stack(list(self._jacobian[x].values())),
                    np.shape(x) + np.shape(self),
                ),
                new_axes,
            )
        return self._jacobian[x]
    else:
        if self._diff is None:
            self._diff = {}
        if x not in self._diff:
            # Default to an all-ones tangent when none is given.
            if grad_variables is None:
                grad_variables = np.ones_like(self)
            self._diff[x] = self._forward(x, grad_variables)
        return self._diff[x]
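Putting it together, usage could look like the following sketch (hedged: how JVPDiffArray inputs are constructed and how operators dispatch to register_diff depends on the surrounding library, which is not shown in this article):

# Hypothetical usage, assuming the np wrapper dispatches to JVPDiffArray.
x = JVPDiffArray(np.array([1.0, 2.0, 3.0]))
y = np.sin(x) * x  # each primitive calls register_diff on its result

print(y.to(x))                 # JVP with the default all-ones tangent
print(y.to(x, jacobian=True))  # full Jacobian, one one-hot tangent per element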