Pytorch中的hook函数

一、Hook函数概念

Hook 是 PyTorch 中一个十分有用的特性。利用它,我们可以不必改变网络输入输出的结构,方便地获取、改变网络中间层变量的值和梯度。这个功能被广泛用于可视化神经网络中间层的 feature、gradient,从而诊断神经网络中可能出现的问题,分析网络有效性。

Hook函数机制:不改变主体,实现额外的功能,像一个挂件一样;

参考:知乎大佬的讲解

主要分为:

  1. Hook for Tensors :针对 Tensor 的 hook
  2. Hook for Modules:针对例如 nn.Conv2dnn.Linear等网络模块的 hook
  3. Guided Backpropagation:利用 Hook 实现的一段神经网络可视化代码

首先看一下针对张量的操作!

1、Tensor.register_hook

功能:注册一个反向传播hook函数

hook函数仅一个输入参数,为张量的梯度;

以计算图与梯度求导为例:y = (x + w)*(w+1)

import torch
import torch.nn as nn
from tools.common_tools import set_seed
set_seed(1)

#-------tensor hook---------------
# y = (x+w)*(w+1)
# a = (x+w), b = (w+1)
w = torch.tensor([1.],requires_grad=True)
x = torch.tensor([2.],requires_grad=True)
a = torch.add(w,x)
b = torch.add(w,1)
y = torch.mul(a,b)
a_grad = list()
def grad_hook(grad):
    a_grad.append(grad)
    
handle = a.register_hook(grad_hook)
y.backward()
print("gradient:",w.grad,x.grad,a.grad,b.grad,y.grad)
print("a_grad[0]:",a_grad[0])
gradient: tensor([5.]) tensor([2.]) None None None
a_grad[0]: tensor([2.])
#-----对梯度进行修改操作----------------------------------
# y = (x+w)*(w+1)
w = torch.tensor([1.],requires_grad=True)
x = torch.tensor([2.],requires_grad=True)
a = torch.add(w,x)
b = torch.add(w,1)
y = torch.mul(a,b)
a_grad = list()
# 相比上述梯度而言,这里将梯度值乘以2
def mul_grad_hook(grad):
    grad *= 2  #w=10
    return grad*3  #w=30

handle = w.register_hook(mul_grad_hook)
y.backward()
print("w.grad:",w.grad)
handle.remove()
w.grad: tensor([30.])

下面就来看针对module的hook函数;

2、Module.register_forward_hook

3、Module.register_forward_pre_hook

功能:注册module前向传播前的hook函数

主要参数:

  • module:当前网络层
  • input:当前网络层输入数据

4、Module.register_backward_hook

功能:注册module反向传播的hook函数

参数:

  • module:当前网络层
  • grad_input:当前网络层输入梯度数据
  • grad_output:当前网络层输出梯度数据

网络模块 module 不像上一节中的 Tensor,拥有显式的变量名可以直接访问,而是被封装在神经网络中间。我们通常只能获得网络整体的输入和输出,对于夹在网络中间的模块,我们不但很难得知它输入/输出的梯度,甚至连它输入输出的数值都无法获得。

为了解决这个麻烦,PyTorch 设计了两种 hook:register_forward_hook 和 register_backward_hook,分别用来获取正/反向传播时,中间层模块输入和输出的 feature/gradient,大大降低了获取模型内部信息流的难度。

功能:注册module的前向传播hook函数;

hook(module,input,output)

  • module:当前网络层
  • input:当前网络层输入数据
  • output:当前网络层输出数据

以一个实例来看一下:

#---------------module 中的hook函数-------------------
# define a simple Net
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        # in_channel,out_channel,kernel_size
        self.conv1 = nn.Conv2d(1,2,3)
        # 卷积核尺寸
        self.pool1 = nn.MaxPool2d(2,2)
        
    def forward(self,x):
        x = self.conv1(x)
        x = self.pool1(x)
        return x
    
def forward_hook(module,data_input,data_output):
    fmap_block.append(data_output)
    input_block.append(data_input)
    
def forward_pre_hook(module,data_input):
    print("forward_pre_hook input:{}".format(data_input))

def backward_hook(module,grad_input,grad_output):
    print("backward hook input:{}".format(grad_input))
    print("backward hook output:{}".format(grad_output))

#初始化网络
net = Net()
# detach()就是为例截断反向传播的梯度流
# 卷积核1:值为1
net.conv1.weight[0].detach().fill_(1)
# 卷积核2:值为2
net.conv1.weight[1].detach().fill_(2)
# 偏置
net.conv1.bias.data.detach().zero_()

# 注册hook函数
fmap_block = list()
input_block = list()
net.conv1.register_forward_hook(forward_hook)
net.conv1.register_forward_pre_hook(forward_pre_hook)
net.conv1.register_backward_hook(backward_hook)

# batch size * channel * H * W
fake_img = torch.ones((1,1,4,4))
output = net(fake_img)

loss_fnc = nn.L1Loss()
target = torch.randn_like(output)
loss = loss_fnc(target, output)
loss.backward()

print("output shape: {}\noutput value: {}\n".format(output.shape, output))
print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))
forward_pre_hook input:(tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]]),)
backward hook input:(None, tensor([[[[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]]],


        [[[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]]]]), tensor([0.5000, 0.5000]))
backward hook output:(tensor([[[[0.5000, 0.0000],
          [0.0000, 0.0000]],

         [[0.5000, 0.0000],
          [0.0000, 0.0000]]]]),)
--------------------------------------------------------------------------------------------------------------------
output shape: torch.Size([1, 2, 1, 1])
output value: tensor([[[[ 9.]],

         [[18.]]]], grad_fn=<MaxPool2DWithIndicesBackward>)

feature maps shape: torch.Size([1, 2, 2, 2])
output value: tensor([[[[ 9.,  9.],
          [ 9.,  9.]],

         [[18., 18.],
          [18., 18.]]]], grad_fn=<ThnnConv2DBackward>)

input shape: torch.Size([1, 1, 4, 4])
input value: (tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]]),)

 

 

 

 

 

  • 4
    点赞
  • 33
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
PyTorchhook机制是一种用于在计算图注册回调函数的机制。当计算图被执行时,这些回调函数会被调用,并且可以对计算图间结果进行操作或记录。 在PyTorch,每个张量都有一个grad_fn属性,该属性表示该张量是如何计算得到的。通过在这个grad_fn上注册一个hook函数,可以在计算图的每一步获取该张量的梯度,或者在该张量被使用时获取该张量的值。这些hook函数可以被用来实现一些调试、可视化或者改变计算图的操作。 下面是一个简单的例子,其我们在计算图的每一步都打印出间结果和梯度: ```python import torch def print_tensor_info(tensor): print('Tensor shape:', tensor.shape) print('Tensor value:', tensor) print('Tensor gradient:', tensor.grad) x = torch.randn(2, 2, requires_grad=True) y = x * 2 z = y.mean() # 注册一个hook函数,用来打印间结果和梯度 y.register_hook(print_tensor_info) # 执行计算图 z.backward() # 输出结果 print('x gradient:', x.grad) ``` 在这个例子,我们定义了一个张量x,并计算了y和z。我们在y上注册了一个hook函数,该函数在计算图的每一步都会被调用。然后我们执行了z的反向传播,计算出了x的梯度。最后,我们打印出了x的梯度。 需要注意的是,hook函数不能修改张量的值或梯度,否则会影响计算图的正确性。此外,hook函数只会在计算图的正向传播和反向传播时被调用,而不会在张量被直接使用时被调用。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

kaichu2

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值