next_functions 反向传播计算图的记录。
在反向图的计算中,计算图终止于叶子AccumulateGrad节点。有一个.variable属性指向叶子节点。
例子:
a = torch.randn(1, requires_grad=True)
b = a*(a+2)
print (b.grad_fn.next_functions)
print (b.grad_fn.next_functions[1][0].next_functions)
print (b.grad_fn.next_functions[0][0].variable is a)
输出:
((<AccumulateGrad object at 0x7fbe7aa96780>, 0), (<AddBackward0 object at 0x7fbe7aa96748>, 0))
((<AccumulateGrad object at 0x7fbe7aa96780>, 0), (None, 0))
True
对于a*(a+2) 一个图分支是a,另一个图分支是a+2。第二个图分支含有一个a分支和一个常数2 分支
对反向传播的梯度进行处理
p_tmp = p.expand_as(p)
grad_acc = p_tmp.grad_fn.next_functions[0][0]
hook_removehandle=grad_acc.register_hook(self._make_hook(p))
self._hooks_removableHandle.append(hook_removehandle)
#使用remove去除gradient上面的函数
hook_removehandle.remove()