1. If any input requires a gradient, the output requires a gradient
We can only set the requires_grad flag on leaf nodes of the computation graph to decide whether a tensor records gradients; we cannot set requires_grad on nodes produced by operations. Whether a derived node requires a gradient is determined by its inputs: if any input has requires_grad=True, then the output's requires_grad is also True.
import torch

x = torch.randn(2, 100)
x.requires_grad = False
w = torch.randn(10, 100)
w2 = torch.randn(3, 10)
w.requires_grad = True
w2.requires_grad = True
y = x @ w.t()
z = y @ w2.t()
print(y.requires_grad, z.requires_grad)  # True True
z.sum().backward()
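As a quick sanity check (my addition, not from the original notes), attempting to flip requires_grad on a non-leaf tensor raises a RuntimeError, which is why only leaf tensors can be configured directly:

```python
import torch

a = torch.randn(3, requires_grad=True)
b = a * 2              # non-leaf: produced by an operation
print(b.is_leaf)       # False
try:
    b.requires_grad = False  # only leaf tensors may be set directly
except RuntimeError as e:
    print("cannot set requires_grad on a non-leaf:", e)
```

To actually cut a computed tensor out of the graph, use b.detach() instead, which returns a new leaf tensor sharing the same storage.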
2. Getting the gradient of an intermediate node
For a leaf tensor with requires_grad=True, we can read its gradient via v.grad after backward. For an intermediate tensor, however, v.grad is always None: PyTorch frees intermediate gradients during the backward pass to save memory. To obtain such a gradient, use register_hook, which invokes the registered function with the gradient when it flows back through that tensor. Here is a minimal version that just prints it:
import torch

def hook(grad):
    print(grad)

x = torch.randn(2, 100)
x.requires_grad = False
w = torch.randn(10, 100)
w2 = torch.randn(3, 10)
w.requires_grad = True
w2.requires_grad = True
y = x @ w.t()
z = y @ w2.t()
y.register_hook(hook)
z.sum().backward()  # hook is called here with y's gradient
An improved version that collects gradients by name:
import torch
class GradCollector(object):
    def __init__(self):
        self.grads = {}

    def __call__(self, name: str):
        def hook(grad):
            self.grads[name] = grad
        return hook
x = torch.randn(2, 100)
x.requires_grad = False
w = torch.randn(10, 100)
w2 = torch.randn(3, 10)
w.requires_grad = True
w2.requires_grad = True
y = x @ w.t()
z = y @ w2.t()
grad_collector = GradCollector()
y.register_hook(grad_collector("y"))
z.register_hook(grad_collector('z'))
z.sum().backward()
print(grad_collector.grads['y'])
print(grad_collector.grads['z'])
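As an aside (my addition, using the standard PyTorch API): when you only need to read an intermediate gradient rather than run custom logic, Tensor.retain_grad() is a simpler alternative to register_hook — it makes .grad populated on a non-leaf tensor after backward:

```python
import torch

x = torch.randn(2, 100)
w = torch.randn(10, 100, requires_grad=True)
w2 = torch.randn(3, 10, requires_grad=True)

y = x @ w.t()
y.retain_grad()        # ask autograd to keep y's gradient
z = y @ w2.t()
z.sum().backward()

print(y.grad.shape)    # torch.Size([2, 10])
```

retain_grad costs extra memory for the retained tensor, so hooks remain preferable when the gradient should be consumed on the fly rather than stored.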