pytorch 的 hook 函数分为 torch.Tensor 和 torch.nn.Module 两类, 分别对应
torch.Tensor.register_hook
torch.nn.Module.register_backward_hook
常用的是第一种, 针对 torch.Tensor 使用
import torch
grad_list = []
def print_grad(grad):
grad_list.append(grad)
print(grad_list)
x = torch.randn(2, 1, requires_grad=True)
y = x + 2
z = torch.mean(torch.pow(y, 2))
y.register_hook(print_grad) # 这里要注意, register_hook 接收的是一个函数
z.backward()
register_hook(grad)