Pytorch中的register_hook(梯度操作)

对于高阶调参师而言,对神经网络梯度级别的操作的不可避免的。有时候,咱们需要把某一层的梯度拿出来分析,辅助特征图可视化(如GradCAM);再比如,hook还可以做优化器设计的实验。

hook,在中文里就是“钩子”的意思。Pytorch默认在反向传播过程中,不保留中间层的梯度,以达到减少内存的目的。如果需要对某一层的梯度特别感兴趣,咱们可以用钩子把它勾住,再等反向传播的时候,就可以访问到这一层的梯度。甚至可以对这层梯度进行修改,进而影响浅层的梯度传播

我们先看官方给出的示例:

import torch
v = torch.tensor([0., 0., 0.], requires_grad=True)
h = v.register_hook(lambda grad: grad * 2)  # double the gradient
v.backward(torch.tensor([1., 2., 3.]))
print(v.grad)
h.remove()

这个栗子主要想表达register_hook可以用来修改梯度,register_hook的作用在v.backward()那里开始执行(众所周知,只有反向传播的时候才有梯度)。如第三行所表达的,在张量v这里对梯度放大2倍。
lambda其实就是一个无命名函数的表达,不了解的可以戳《python中的lambda关键字》。
通常,使用register_hook,一般形式是

tensor.register_hook(func)

这个func就是函数名,表示你想对这层的张量进行的操作。如果你只是想读取这一层的张量,那么你可以在func中用一个变量去缓存张量即可,比如:

cache = []
def func(grad):
    cache.append(grad)
    return grad

咱们可以在cache中获取到梯度了。因为即使是hook住的张量梯度,也会在运行结束后立刻释放掉,所以想要保存梯度,一定要在func函数里进行


在神经网络中,咱们可以通过层的命名来获取该层的输出张量。比如对于检测器来讲,一般是backbone+head形式,假设咱们想获取backbone的梯度,可以考虑使用register_hook勾住backbone的输出张量,然后通过类似的方法保存住。

参考:https://pytorch.org/docs/stable/generated/torch.Tensor.register_hook.html
https://zhuanlan.zhihu.com/p/267130090

PyTorchhook机制是一种用于在计算图注册回调函数的机制。当计算图被执行时,这些回调函数会被调用,并且可以对计算图间结果进行操作或记录。 在PyTorch,每个张量都有一个grad_fn属性,该属性表示该张量是如何计算得到的。通过在这个grad_fn上注册一个hook函数,可以在计算图的每一步获取该张量的梯度,或者在该张量被使用时获取该张量的值。这些hook函数可以被用来实现一些调试、可视化或者改变计算图的操作。 下面是一个简单的例子,其我们在计算图的每一步都打印出间结果和梯度: ```python import torch def print_tensor_info(tensor): print('Tensor shape:', tensor.shape) print('Tensor value:', tensor) print('Tensor gradient:', tensor.grad) x = torch.randn(2, 2, requires_grad=True) y = x * 2 z = y.mean() # 注册一个hook函数,用来打印间结果和梯度 y.register_hook(print_tensor_info) # 执行计算图 z.backward() # 输出结果 print('x gradient:', x.grad) ``` 在这个例子,我们定义了一个张量x,并计算了y和z。我们在y上注册了一个hook函数,该函数在计算图的每一步都会被调用。然后我们执行了z的反向传播,计算出了x的梯度。最后,我们打印出了x的梯度。 需要注意的是,hook函数不能修改张量的值或梯度,否则会影响计算图的正确性。此外,hook函数只会在计算图的正向传播和反向传播时被调用,而不会在张量被直接使用时被调用。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

木盏

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值