我们知道,在计算图运行的过程中,中间变量的数据是不会被保留的。想要保存中间变量,一个方法是使用 tnesor.retain_grad = True
,但是这样会使数据被永久保存,造成存储空间的消耗,即我们可能只需要用到中间变量的值一次,但它却被永久保存了。为了解决这个问题,PyTorch 引入了 hook 函数。hook 函数分为张量的 hook 函数和神经网络层的 hook 函数两种。
1. 张量的 hook 函数
Tensor.register_hook(hook)
这个 hook 函数在每次计算反向传播的时候会被调用。tensor 的 hook 函数有个性质:
hook(grad) -> Tensor or None
即如果 hook 函数不返回任何值,张量的导数不会改变;如果有返回值,张量原本的导数会被返回值覆盖。张量的 hook 函数的使用方法如下:
- 定义计算图;
- 定义 hook 函数;
- 注册 hook 函数;
- 反向传播;
- 移除 hook 函数。
我们还是使用 PyTorch 折桂 4:torch.autograph 中计算导数的例子:
w = torch.tensor([1.], requires_grad=True) # 定义计算图
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
a_grad = list()
def grad_hook(grad): # 定义 hook 函数
a_grad.append(grad)
return grad * 3
handle = w.register_hook(grad_hook) # 注册 hook 函数
y.backward() # 反向传播
# 查看梯度
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
print("a_grad[0]: ", a_grad[0])
print("w.grad: ", w.grad)
handle.remove() # 移除 hook 函数
运行后 PyTorch 先抛出来一个警告,说试图获得非叶子节点的导数。然后结果如下:
gradient: tensor([15.]) tensor([2.]) None None None
a_grad[0]: tensor([5.])
w.grad: tensor([15.])
可以看到,本来 a
的导数已经被释放掉了,但是我们通过 hook 函数把 a
的导数保存了起来。由于有返回值,导致 w
的导数被覆盖了。
2. 神经网络的 hook 函数
神经网络的 hook 函数分为前向传播的 hook 函数和反向传播的 hook 函数,分别为 register_forward_hook(hook)
和 register_backward_hook(hook)
,使用方法与张量的 hook 函数大同小异。我们还是来看一个例子:
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc = nn.Linear(5, 1)
def forward(self, x):
x = self.fc(x)
return x
def forward_hook(module, data_input, data_output):
fmap_block.append(data_output)
input_block.append(data_input)
def backward_hook(module, grad_input, grad_output):
print("backward hook input:{}".format(grad_input))
print("backward hook output:{}".format(grad_output))
# 初始化网络
net = Net()
net.fc.weight.detach().fill_(2)
net.fc.bias.data.detach().zero_()
# 注册hook
fmap_block = list()
input_block = list()
net.fc.register_forward_hook(forward_hook)
net.fc.register_backward_hook(backward_hook)
# inference
x = torch.ones(5)
output = net(x)
loss_fnc = nn.L1Loss()
target = torch.randn_like(output)
loss = loss_fnc(target, output)
loss.backward()
这里我们定义了一个单层全连接网络,全连接网络的权重为 2,使用 l1 损失函数,获得的全连接层的反向传播导数为:
backward hook input:(tensor([1.]), tensor([1.]))
backward hook output:(tensor([1.]),)