pytorch的那些坑:
提示:这里简述项目相关背景:
问题描述:
提示:pytorch里查看变量的梯度:
只有叶子节点pytorch才会保留梯度,其他节点是不会保存梯度的,所以查看非叶子节点的梯度时会显示为none,如果需要查看非叶子节点的梯度需要给这个变量加上retain_grad():
x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
y.retain_grad()
z = y * y * 3
out = z.mean()
out.backward()
print(y.grad)
> tensor([[4.5000, 4.5000],
[4.5000, 4.5000]])