一、Hook函数概念
1.1 Hook引入的原因
Pytorch的运行机制是动态计算图,动态图运算结束后,一些中间变量(如feature map和非叶子结点的梯度)会被释放掉,但是往往有时候我们需要获取这些中间变量,这时就可以通过Hook函数在主体中根据Hook机制添加额外的函数来获取或改变中间变量
1.2 Hook函数机制
Hook函数机制: 不改变主体(前向传播和后向传播),实现额外功能,像一个挂件,挂钩, hook
nn.module中的call()函数的运行机制也正是hook函数机制,整个call函数分为四个部分,分别是:
- forward_pre_hook
- forward
- forward_hook
- backward_hook
如上图所示,call()函数执行forward_pre_hook函数,然后执行forward前向传播过程,接着执行forward_hook函数,最后执行back_forward函数
所以,在前向传播过程中,不仅仅只是单纯地执行前项传播,而是会提供hook函数接口,来实现额外的操作和功能
1.3 四种hook函数
主要分为三类:针对tensor的,前向传播的,和后向传播的
- torch.Tensor.register_hook(hook)
- torch.nn.Module.register_forward_hook
- torch.nn.Module.register_forward_pre_hook
- torch.nn.Module.register_backward hook
二、Hook函数与特征提取
2.1 Tensor.register_hook
hook(grad)
功能: 注册一个反向传播hook函数
Hook函数仅一个输入参数,为张量的梯度,返回张量或者无返回
示例:通过hook函数获取和改变非叶子结点的梯度
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
from tools.common_tools import set_seed
set_seed(1) # 设置随机种子
# ----------------------------------- 1 tensor hook 1 -----------------------------------
# flag = 0
flag = 1
if flag:
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
a_grad = list() # 存储张量的梯度
def grad_hook(grad):
a_grad.append(grad)
handle = a.register_hook(grad_hook) # 把定义的函数注册到对应的张量上
y.backward()
# 查看梯度
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
print("a_grad[0]: ", a_grad[0])
handle.remove()
# ----------------------------------- 2 tensor hook 2 -----------------------------------
# flag = 0
flag = 1
if flag:
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
a_grad = list()
def grad_hook(grad): # 定义hook函数修改张量梯度
grad *= 2
return grad*3 # 通过return返回的梯度会覆盖掉原梯度
handle