PyTorch: How to Get the Child Nodes of a Tensor in the Backward Graph
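```python
def get_children(tensor):
    # A leaf tensor (or one that does not require grad) has no grad_fn, hence no children
    if tensor.grad_fn is None:
        return []
    # next_functions yields (node, input_index) pairs; node is None for inputs that don't require grad.
    # AccumulateGrad nodes wrap leaf tensors (exposed as .variable); other backward nodes are returned as-is.
    return [getattr(nf, "variable", nf) for nf, _ in tensor.grad_fn.next_functions if nf is not None]
```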
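`next_functions` only gives one level of the graph. To see the whole backward graph, the same link can be followed recursively from `tensor.grad_fn`. The sketch below is an illustration under that assumption; `walk_backward_graph` is a hypothetical helper, not a PyTorch API. It relies only on `grad_fn`, `next_functions`, and the `.variable` attribute of `AccumulateGrad` nodes, and it keeps a visited set so shared subgraphs are listed once.

```python
import torch


def walk_backward_graph(tensor):
    """Hypothetical helper: depth-first walk of the autograd graph rooted at `tensor`.

    Returns a list of (depth, description) pairs, one per backward node.
    """
    visited = set()
    result = []

    def visit(node, depth):
        # None marks an input that does not require grad; skip nodes already seen
        if node is None or node in visited:
            return
        visited.add(node)
        name = type(node).__name__
        if hasattr(node, "variable"):
            # AccumulateGrad: wraps a leaf tensor that requires grad
            name += f" (leaf, shape={tuple(node.variable.shape)})"
        result.append((depth, name))
        for child, _ in node.next_functions:
            visit(child, depth + 1)

    visit(tensor.grad_fn, 0)
    return result


a = torch.tensor([1.0, 2.0], requires_grad=True)
b = torch.tensor([3.0, 4.0], requires_grad=True)
c = torch.tensor([5.0, 6.0])  # does not require grad, so it never appears in the graph
y = (a * b + c).sum()

for depth, name in walk_backward_graph(y):
    print("  " * depth + name)
```

Each indented line is a child of the nearest line above it with smaller indentation. The constant `c` does not show up, because its slot in `next_functions` is `None`.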