python从入门到入土_自动微分从入门到入土(二):实现VJP

本文介绍一下如何从vjp的角度出发构建一个自动微分框架。

1 基本vjp微分算子

vjp微分算子是从vjp角度构建自动微分的基石。因为部分微分算子的构建过于复杂,而且容易出错,我们直接采用autograd框架中的vjp微分算子的定义方法。一个简单的二元微分算子如下:

defvjp(

np.subtract,

lambda ans, x, y : unbroadcast_f(x, lambda g: g),

lambda ans, x, y : unbroadcast_f(y, lambda g: -g)

)

代码中的defvjp注册二元函数np.subtract对应的两个微分算子,每个微分计算都需要输入至少四个变量,分两个阶段输入。

第一个阶段(计算图构建阶段)输入计算的结果ans,两个输入x和y,此外,一些其他的可能对求导有影响的参数,如np.sum的axis等也要在这个阶段输入。第二个阶段(反向传播)输入该结点的梯度g。

2 计算图构建

我们可以将任意计算组织成一张有向无环图的形式。如公式f\left(x_{1}, x_{2}\right)=\ln \left(x_{1}\right)+x_{1} x_{2}-\sin \left(x_{2}\right), 假设x1=2, x2=5, 计算图可表示如下:

cbac1f6366e7b78f5995f4a00ec2006e.png

为了后续的反向传播,我们需要利用一种数据结构将计算图中的每个结点保留下来,假设这种数据结构为VJPDiffArray。它应该具有以下属性:

id: 结点的id

value: 结点的数值, 如ln2

parents: 指向当前结点的其他结点,如结点v4的parents就是[v2, v3]

vjp: 在前向传播时将ans、x、y等输入微分算子得到的求导函数,只要在后向传播时再输入g就可以得到当前结点对parents梯度。

所有的这些属性在前向传播时通过一个叫register_diff的函数进行赋值。

def register_diff(self, func, args, kwargs):

"""

Register the derivative function used in backward propagation

for the current node.

"""

try:

# Get the corresponding differential operator.

if func is np.ufunc.__call__:

vjpmaker = primitive_vjps[args[0]]

else:

vjpmaker = primitive_vjps[func]

except KeyError:

raise NotImplementedError("VJP of func not defined")

vjp_args = []

if self._parents is None:

self._parents = []

for arg in args:

if isinstance(arg, VJPDiffArray):

# Register parents

self._parents.append(arg)

vjp_args.append(arg)

elif not isinstance(arg, np.ufunc):

vjp_args.append(arg)

parent_argnums = tuple(range(len(self._parents)))

# Input ans, x, y to get inner func

self._vjp = vjpmaker(parent_argnums, self, tuple(vjp_args), kwargs)

3 反向传播

在构建好计算图后,求导的过程就比较简单了(虽然我一开始将它想的很复杂),我们可以将其表达成一个图上的反向传播过程:

a8a19c9d38af2a5aba2d202ae6871f81.png

对于每一个结点,我们输入上一个结点的梯度对当前结点的梯度,获得当前结点对于parents的梯度,这个过程通过_backward函数完成:

def _backward(self, grad_variables, end_node, base):

"""

Backpropagation.

Traverse computation graph backwards in topological order from the end node.

For each node, compute local gradient contribution and accumulate.

"""

if grad_variables is None:

# For the starting node v6, the input grad_variables

# represents the weight of each component.

# If there is no input, initialize to 1.

grad_variables = np.ones_like(self.value)

if end_node is None:

end_node = self

if base is None or base.id == self.id:

if self._diff is None:

self._diff = {}

if end_node in self._diff:

self._diff[end_node] = self._diff[end_node] + grad_variables

else:

self._diff[end_node] = grad_variables

if self._vjp:

diffs = list(self._vjp(grad_variables))

for i, p in enumerate(self._parents):

p._backward(diffs[i], end_node, base)

4 求Jacobian

在完成上面的两个函数后,我们基于此快速实现获取真实的jacobian矩阵。因为vjp就是np.sum(y*v)对x求导,所以jacobian矩阵j上的一部分j[i][j](代表y[i][j]对x的导数)可以看作是使用用一个仅[i,j]位置为1,其他位置都为0的矩阵v乘y对x的vjp。

实现如下:

def _backward_jacobian(self, grad_variables, end_node, position, base):

if base is None or base.id == self.id:

if self._jacobian is None:

self._jacobian = {}

if end_node not in self._jacobian:

self._jacobian[end_node] = {}

if position not in self._jacobian[end_node]:

self._jacobian[end_node][position] = grad_variables

else:

self._jacobian[end_node][position] = (

self._jacobian[end_node][position] + grad_variables

)

if self._vjp:

diffs = list(self._vjp(grad_variables))

for i, p in enumerate(self._parents):

p._backward_jacobian(diffs[i], end_node, position, base)

最后提供一个对外的调用APIto:

def to(self, x, grad_variables=None, jacobian=False):

"""

Calculate the VJP or Jacobian matrix of self to x.

"""

if jacobian:

if x._jacobian is None or self not in x._jacobian:

for position in itertools.product(*[range(i) for i in np.shape(self)]):

grad_variables = np.zeros_like(self.value)

grad_variables.value[position] = 1

self._backward_jacobian(grad_variables, self, position, x)

x._jacobian[self] = np.reshape(

np.stack(x._jacobian[self].values()), np.shape(self) + np.shape(x)

)

return x._jacobian[self]

else:

if x._diff is None or self not in x._diff:

self._backward(grad_variables, self, x)

return x._diff[self]

注意因为我们在反向传播进行计算时也会注册梯度(调用register_diff),所以在vjp模式中获取高阶微分只需要反复调用to函数即可。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值