I. References
https://blog.51cto.com/u_15047489/2619287
https://zhuanlan.zhihu.com/p/83172023
II. Leaf nodes and result nodes
PyTorch describes computation with a computation graph (a dynamic graph; TensorFlow uses a static one). The graph has two elements: nodes and edges. Nodes are tensors, which divide into leaf nodes and result nodes (check with tensor.is_leaf; every tensor created directly by the user is a leaf node), and operations (convolution, addition, multiplication, etc.) are the edges. The distinction matters during backpropagation: gradients of leaf nodes are kept and can be read through tensor.grad, while gradients of result nodes are freed automatically to save memory, so reading them returns None. If you want to keep a result node's gradient, call tensor.retain_grad() on it before backward().
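A minimal sketch of these properties (the variable names are my own, not from the examples below):

```python
import torch

a = torch.tensor(2.0, requires_grad=True)  # created by the user: a leaf node
y = a * 3                                  # produced by an operation: a result node

print(a.is_leaf)  # True
print(y.is_leaf)  # False

y.backward()
print(a.grad)     # tensor(3.) -- leaf gradients are kept
print(y.grad)     # None -- result-node gradients are freed (PyTorch warns here)
```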
1. The computation graph is destroyed after backward()
import torch
a = torch.tensor(1.0, requires_grad=True)
y = a ** 2
z = y * 3
y.backward()                    # frees the graph's saved tensors
print(a.grad)                   # tensor(2.)
z.backward(retain_graph=True)   # raises: the graph was already freed by y.backward()
print(a.grad)                   # never reached
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
2. Keeping the computation graph with backward(retain_graph=True)
import torch
a = torch.tensor(1.0, requires_grad=True)
y = a ** 2
z = y * 3
y.backward(retain_graph=True)   # dy/da = 2a = 2
print(a.grad)                   # tensor(2.)
z.backward(retain_graph=True)   # dz/da = 6a = 6, accumulated into the existing grad
print(a.grad)                   # tensor(8.) = 2 + 6
tensor(2.)
tensor(8.)
3. To save memory, the grad of intermediate variables is not kept
import torch
a = torch.tensor(1.0, requires_grad=True)
y = a ** 2
z = y * 3
z.backward(retain_graph=True)
print(y.grad)
print(a.grad)
UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor.
None
tensor(6.)
4. Keeping an intermediate variable's grad with retain_grad()
import torch
a = torch.tensor(1.0, requires_grad=True)
y = a ** 2
print(y.retains_grad)
y.retain_grad()
print(y.retains_grad)
z = y * 3
z.backward(retain_graph=True)
print(y.grad)
print(a.grad)
False
True
tensor(3.)
tensor(6.)
III. torch.autograd.backward
torch.autograd.backward() is equivalent to tensor.backward(): it computes gradients for all leaf nodes. The grad_tensors argument is needed because gradients can only be computed from a scalar; when the tensor is a vector or a matrix, it must first be reduced to a scalar, and grad_tensors supplies the per-element weights of that reduction (it has the same shape as the tensor). Also, once a tensor has been differentiated, the computation graph is destroyed and backward() cannot be called again; pass retain_graph=True to keep the graph for another backward pass.
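A small sketch of grad_tensors acting as per-element weights on a vector output (values chosen for illustration):

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x ** 2                      # vector output: y.backward() alone would fail

# dy_i/dx_i = 2*x_i = [2, 4, 6]; grad_tensors weights each element
w = torch.tensor([1.0, 1.0, 0.5])
torch.autograd.backward(y, grad_tensors=w)
print(x.grad)                   # tensor([2., 4., 3.])
```

This is the same as calling backward() on the weighted scalar (y * w).sum(), which is why grad_tensors can be read as the weights of the scalar reduction.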
IV. torch.autograd.grad
Computes the gradient of outputs with respect to inputs. You can differentiate with respect to specified inputs only, or with respect to all leaf nodes (controlled by only_inputs, default True; note that this argument is deprecated in recent PyTorch versions, which suggest using tensor.backward() when you want gradients accumulated into all leaves).
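A minimal sketch, reusing the z = 3a² setup from the earlier examples; unlike backward(), grad() returns the gradients as a tuple instead of accumulating them into .grad:

```python
import torch

a = torch.tensor(1.0, requires_grad=True)
y = a ** 2
z = y * 3

# gradient of z with respect to a only; returns a tuple
(g,) = torch.autograd.grad(z, a)
print(g)       # tensor(6.)
print(a.grad)  # None -- grad() does not populate .grad
```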