Pytorch获取中间变量的梯度
为了节省显存,pytorch在反向传播的过程中只保留了计算图中的叶子结点的梯度值,而未保留中间节点的梯度,如下例所示:import torchx = torch.tensor(3., requires_grad=True)y = x ** 2z = 4 * yz.backward()print(x.grad) # tensor(24.)print(y.grad) # None可以看到当进行反向传播后,只保留了x的梯度tensor(24.),而y的梯度没有保留所以为None。








