pytorch中的hook
为了方便,我都是用单输入单输出的结构来测试的。
pytorch提供三种hook,一种用来看tensor的grad,一种用来看一个层的输入输出,还有一种用来看一个层输入输出的梯度。
所有的hook函数,都需要定义一个function供它使用。
tensor.register_hook(function)是最好理解的,能够获取的就是这个tensor的梯度。它需要的function格式为hook(grad),这里的grad就是tensor的梯度。
Conv2d.forward_hook(function)(当然其他一些layer也可以,这里就以Conv2d举例了),它需要的function格式为forward_hook(module,input,output)。这里的module就是这个Conv2d,input和output分别是前向传播时这个Conv2d的输入和输出。
Conv2d.backard_hook(function),它需要的function格式为backward_hook(module,input_grad,output_grad)。这里的input和output,是反向传播时来说的(也就是和forward时相反),下面说的输入和输出,也是按反向传播的方向来说的。其中input_grad里有三个grad,按顺序分别是输入tensor的grad,layer自身weight的grad和layer自身bias的grad;output_grad中只有一个grad,就是输出tensor的grad。
这些hook,如果只是要打印一下需要的参数,比较方便,如果想要存起来供其他工具使用(比如tensorboardX),那么最好弄一个list来存储它们。
按照网络正常执行的顺序,forward_hook会最先输出,然后按照反向传播的顺序,backward_hook和tensor上的hook会依次输出。
不需要使用hook时一定要释放,否则显存会炸掉的…
测试代码:
import torch
import torch.nn as nn
import numpy as np
conv1=nn.Conv2d(1,1,3,1,1,bias=False)
conv2=nn.Conv2d(1,1,3,1,1,bias=False)
conv1.weight.data=torch.tensor([[[[1,2