一、Hook函数概念
Hook函数机制:不改变模型主体,实现额外功能,像一个挂件或挂钩等。
为什么需要这个函数呢?这与Pytorch的动态图计算机制有关,在动态图的计算过程中,一些中间变量会释放掉,比如特征图、非叶子节点的梯度,在模型前向传播、反向传播的时候添加hook这个额外函数,提取一些释放掉而后面又需要用到的变量,也可以用hook函数来改变中间变量的梯度。
Pytorch中提供四种hook函数:
1、torch.Tensor.register_hook(hook): 针对tensor
2、torch.nn.Module.register_forward_hook:后面这三个针对Module
3、torch.nn.Module.register_forward_pre_hook
4、torch.nn.Module.register_backward_hook
二、Hook函数与特征提取
1、torch.Tensor.register_hook()
功能:这是一个针对张量的hook函数,作用是注册一个反向传播的hook函数,为什么是在反向传播呢?因为只有在反向传播过程中非叶子的梯度会释放掉,用hook函数来保存这些中间变量的信息。
hook(grad) -> Tensor or None
hook函数仅有一个输入参数为张量的梯度,返回值是tensor或者none
例如:
下图是pytorch中一个简单的计算图与梯度求导
在上面计算图反响传播过程中,非叶子节点a和b的梯度会释放掉,在前面的学习中可知retain_grad()可保留参数的梯度,也可用hook函数来保留梯度,如下所示:
# 构建计算图,在反向传播中用hook来保存a的梯度
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
# 构建一个list用来存储a的梯度
a_grad = list()
# 自定义hook函数,存放a的梯度,然后将a的梯度存放到前面构建的list中
def grad_hook(grad):
a_grad.append(grad)
# 接受一个hook函数的钩子,相当于把hook函数挂到计算图上,这样在反向传播时可以保存a的梯度
handle = a.register_hook(grad_hook)
y.backward()
# 查看梯度
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
print(