register_hook

1.Pytorch中autograd以及hook函数详解

https://oldpan.me/archives/pytorch-autograd-hook

2.知乎知识内容

https://www.zhihu.com/question/61044004

作者:李斌
链接:https://www.zhihu.com/question/61044004/answer/183682138
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
 

首先明确一点,有哪些hook?

我看到的有3个:

1. torch.autograd.Variable.register_hook (Python method, in Automatic differentiation package

2. torch.nn.Module.register_backward_hook (Python method, in torch.nn)

3. torch.nn.Module.register_forward_hook

第一个是register_hook,是针对Variable对象的,后面的两个:register_backward_hook和register_forward_hook是针对nn.Module这个对象的。

 

其次,明确一下,为什么需要用hook

打个比方,有这么个函数, [公式][公式][公式] 你想通过梯度下降法求最小值。在PyTorch里面很容易实现,你只需要:

import torch
from torch.autograd import Variable

x = Variable(torch.randn(2, 1), requires_grad=True)
y = x+2
z = torch.mean(torch.pow(y, 2))
lr = 1e-3
z.backward()
x.data -= lr*x.grad.data

但问题是,如果我想要求中间变量 [公式]的梯度,系统会返回错误。

事实上,如果你输入:

type(y.grad)

系统会告诉你:NoneType

这个问题在PyTorch的论坛上有人提问过,开发者说是因为当初开发时设计的是,对于中间变量,一旦它们完成了自身反传的使命,就会被释放掉。

 

因此,hook就派上用场了。简而言之,register_hook的作用是,当反传时,除了完成原有的反传,额外多完成一些任务。你可以定义一个中间变量的hook,将它的grad值打印出来,当然你也可以定义一个全局列表,将每次的grad值添加到里面去。

import torch
from torch.autograd import Variable

grad_list = []

def print_grad(grad):
    grad_list.append(grad)

x = Variable(torch.randn(2, 1), requires_grad=True)
y = x+2
z = torch.mean(torch.pow(y, 2))
lr = 1e-3
y.register_hook(print_grad)
z.backward()
x.data -= lr*x.grad.data

需要注意的是,register_hook函数接收的是一个函数,这个函数有如下的形式:

hook(grad) -> Variable or None

也就是说,这个函数是拥有改变梯度值的威力的!

 

至于register_forward_hook和register_backward_hook的用法和这个大同小异。只不过对象从Variable改成了你自己定义的nn.Module。

当你训练一个网络,想要提取中间层的参数、或者特征图的时候,使用hook就能派上用场了。

 

参考资料:

1. Why cant I see .grad of an intermediate variable?

2. Extract feature maps from intermediate layers without modifying forward()

 

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值