一、Hook函数概念
Hook 是 PyTorch 中一个十分有用的特性。利用它,我们可以不必改变网络输入输出的结构,方便地获取、改变网络中间层变量的值和梯度。这个功能被广泛用于可视化神经网络中间层的 feature、gradient,从而诊断神经网络中可能出现的问题,分析网络有效性。
Hook函数机制:不改变主体,实现额外的功能,像一个挂件一样;
参考:知乎大佬的讲解
主要分为:
- Hook for Tensors :针对 Tensor 的 hook
- Hook for Modules:针对例如
nn.Conv2d
nn.Linear
等网络模块的 hook - Guided Backpropagation:利用 Hook 实现的一段神经网络可视化代码
首先看一下针对张量的操作!
1、Tensor.register_hook
功能:注册一个反向传播hook函数
hook函数仅一个输入参数,为张量的梯度;
以计算图与梯度求导为例:
import torch
import torch.nn as nn
from tools.common_tools import set_seed
set_seed(1)
#-------tensor hook---------------
# y = (x+w)*(w+1)
# a = (x+w), b = (w+1)
w = torch.tensor([1.],requires_grad=True)
x = torch.tensor([2.],requires_grad=True)
a = torch.add(w,x)
b = torch.add(w,1)
y = torch.mul(a,b)
a_grad = list()
def grad_hook(grad):
a_grad.append(grad)
handle = a.register_hook(grad_hook)
y.backward()
print("gradient:",w.grad,x.grad,a.grad,b.grad,y.grad)
print("a_grad[0]:",a_grad[0])
gradient: tensor([5.]) tensor([2.]) None None None a_grad[0]: tensor([2.])
#-----对梯度进行修改操作----------------------------------
# y = (x+w)*(w+1)
w = torch.tensor([1.],requires_grad=True)
x = torch.tensor([2.],requires_grad=True)
a = torch.add(w,x)
b = torch.add(w,1)
y = torch.mul(a,b)
a_grad = list()
# 相比上述梯度而言,这里将梯度值乘以2
def mul_grad_hook(grad):
grad *= 2 #w=10
return grad*3 #w=30
handle = w.register_hook(mul_grad_hook)
y.backward()
print("w.grad:",w.grad)
handle.remove()
w.grad: tensor([30.])
下面就来看针对module的hook函数;
2、Module.register_forward_hook
3、Module.register_forward_pre_hook
功能:注册module前向传播前的hook函数
主要参数:
- module:当前网络层
- input:当前网络层输入数据
4、Module.register_backward_hook
功能:注册module反向传播的hook函数
参数:
- module:当前网络层
- grad_input:当前网络层输入梯度数据
- grad_output:当前网络层输出梯度数据
网络模块 module 不像上一节中的 Tensor,拥有显式的变量名可以直接访问,而是被封装在神经网络中间。我们通常只能获得网络整体的输入和输出,对于夹在网络中间的模块,我们不但很难得知它输入/输出的梯度,甚至连它输入输出的数值都无法获得。
为了解决这个麻烦,PyTorch 设计了两种 hook:register_forward_hook
和 register_backward_hook
,分别用来获取正/反向传播时,中间层模块输入和输出的 feature/gradient,大大降低了获取模型内部信息流的难度。
功能:注册module的前向传播hook函数;
hook(module,input,output)
- module:当前网络层
- input:当前网络层输入数据
- output:当前网络层输出数据
以一个实例来看一下:
#---------------module 中的hook函数-------------------
# define a simple Net
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
# in_channel,out_channel,kernel_size
self.conv1 = nn.Conv2d(1,2,3)
# 卷积核尺寸
self.pool1 = nn.MaxPool2d(2,2)
def forward(self,x):
x = self.conv1(x)
x = self.pool1(x)
return x
def forward_hook(module,data_input,data_output):
fmap_block.append(data_output)
input_block.append(data_input)
def forward_pre_hook(module,data_input):
print("forward_pre_hook input:{}".format(data_input))
def backward_hook(module,grad_input,grad_output):
print("backward hook input:{}".format(grad_input))
print("backward hook output:{}".format(grad_output))
#初始化网络
net = Net()
# detach()就是为例截断反向传播的梯度流
# 卷积核1:值为1
net.conv1.weight[0].detach().fill_(1)
# 卷积核2:值为2
net.conv1.weight[1].detach().fill_(2)
# 偏置
net.conv1.bias.data.detach().zero_()
# 注册hook函数
fmap_block = list()
input_block = list()
net.conv1.register_forward_hook(forward_hook)
net.conv1.register_forward_pre_hook(forward_pre_hook)
net.conv1.register_backward_hook(backward_hook)
# batch size * channel * H * W
fake_img = torch.ones((1,1,4,4))
output = net(fake_img)
loss_fnc = nn.L1Loss()
target = torch.randn_like(output)
loss = loss_fnc(target, output)
loss.backward()
print("output shape: {}\noutput value: {}\n".format(output.shape, output))
print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))
forward_pre_hook input:(tensor([[[[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.]]]]),) backward hook input:(None, tensor([[[[0.5000, 0.5000, 0.5000], [0.5000, 0.5000, 0.5000], [0.5000, 0.5000, 0.5000]]], [[[0.5000, 0.5000, 0.5000], [0.5000, 0.5000, 0.5000], [0.5000, 0.5000, 0.5000]]]]), tensor([0.5000, 0.5000])) backward hook output:(tensor([[[[0.5000, 0.0000], [0.0000, 0.0000]], [[0.5000, 0.0000], [0.0000, 0.0000]]]]),) -------------------------------------------------------------------------------------------------------------------- output shape: torch.Size([1, 2, 1, 1]) output value: tensor([[[[ 9.]], [[18.]]]], grad_fn=<MaxPool2DWithIndicesBackward>) feature maps shape: torch.Size([1, 2, 2, 2]) output value: tensor([[[[ 9., 9.], [ 9., 9.]], [[18., 18.], [18., 18.]]]], grad_fn=<ThnnConv2DBackward>) input shape: torch.Size([1, 1, 4, 4]) input value: (tensor([[[[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.]]]]),)