1 概念介绍
在pytorch的tensor类中,有个is_leaf的属性,表示这个tensor是否是叶子节点:is_leaf 为False的时候,则不是叶子节点, is_leaf为True的时候为叶子节点(有的地方也叫做叶子张量)
1.1 为什么要叶子节点?
对于tensor中的 requires_grad()属性,当requires_grad()为True时我们将会记录tensor的运算过程并为自动求导做准备。
但是并不是每个requires_grad()设为True的值都会在backward的时候得到相应的grad。它还必须为leaf。
这就说明. leaf成为了在 requires_grad()下判断是否需要保留 grad的前提条件
- 提出叶子张量的原因是为了节省内存/显存
- 那些非叶子结点是通过用户所定义的叶子节点的一系列运算生成的(也即中间变量)
- 一般情况下,用户不会去使用这些中间变量的导数,所以为了节省内存,它们在用完之后就被释放了
1.2 哪些张量是叶子张量?
- 所有requires_grad为False的张量(Tensor) 都为叶张量( leaf Tensor)
x_=torch.arange(10,dtype=torch.float32).reshape(10,1) x_.is_leaf #True
- requires_grad为True的张量(Tensor),如果他们是由用户创建的,则它们是叶张量(leaf Tensor).这意味着它们不是运算的结果,因此gra_fn为None
xx=torch.arange(10,dtype=torch.float32,requires_grad=True).reshape(10,1) xx.is_leaf #False
xx=torch.arange(10,dtype=torch.float32,requires_grad=True).reshape(10,1) ww=torch.arange(10,dtype=torch.float32,requires_grad=True).reshape(1,10) yy=ww@xx yy.backward() xx.grad,ww.grad #(None, None) ''' UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. '''
- 只有是叶张量的tensor在反向传播时才会将本身的grad传入的backward的运算中.。如果想得到当前自己创建的,requires_grad为True的tensor在反向传播时的grad, 可以用retain_grad()这个属性
xx=torch.arange(10,dtype=torch.float32,requires_grad=True).reshape(10,1) ww=torch.arange(10,dtype=torch.float32,requires_grad=True).reshape(1,10) yy=ww@xx xx.retain_grad() ww.retain_grad() yy.backward() xx.grad,ww.grad ''' (tensor([[0.], [1.], [2.], [3.], [4.], [5.], [6.], [7.], [8.], [9.]]), tensor([[0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]])) '''
2 保存中间变量的梯度
- 如果需要保留中间变量的导数,那么可以使用tensor.retain_grad()
- 哪一个张量需要保存,哪一个张量加上retain_grad()
loss = l4.mean()
l4.retain_grad()
loss.backward()
print(l4.grad)
3 输出中间变量的梯度
如果我们只是想进行 debug,只需要输出中间变量的导数信息,而不需要保存它们,我们还可以使用 tensor.register_hook
loss = l4.mean()
l4.register_hook(lambda grad: print('l4 grad:', grad))
loss.backward()