pytorch 的 hook 机制
在看pytorch
官方文档的时候,发现在nn.Module
部分和Variable
部分均有hook
的身影。感到很神奇,因为在使用tensorflow
的时候没有碰到过这个词。所以打算一探究竟。
Variable 的 hook
register_hook(hook)
注册一个backward
钩子。
每次gradients
被计算的时候,这个hook
都被调用。hook
应该拥有以下签名:
hook(grad) -> Variable or None
hook
不应该修改它的输入,但是它可以返回一个替代当前梯度的新梯度。
这个函数返回一个 句柄(handle
)。它有一个方法 handle.remove()
,可以用这个方法将hook
从module
移除。
例子:
- 1
- 2
- 3
- 4
- 5
- 6
- 1
- 2
- 3
- 4
- 5
- 6
- 1
- 2
- 3
- 4
- 1
- 2
- 3
- 4
nn.Module的hook
register_forward_hook(hook)
在module
上注册一个forward hook
。
每次调用forward()
计算输出的时候,这个hook
就会被调用。它应该拥有以下签名:
hook(module, input, output) -> None
hook
不应该修改 input
和output
的值。 这个函数返回一个 句柄(handle
)。它有一个方法 handle.remove()
,可以用这个方法将hook
从module
移除。
看这个解释可能有点蒙逼,但是如果要看一下nn.Module
的源码怎么使用hook
的话,那就乌云尽散了。
先看 register_forward_hook
- 1
- 2
- 3
- 4
- 5
- 1
- 2
- 3
- 4
- 5
这个方法的作用是在此module
上注册一个hook
,函数中第一句就没必要在意了,主要看第二句,是把注册的hook
保存在_forward_hooks
字典里。
再看 nn.Module
的__call__
方法(被阉割了,只留下需要关注的部分):
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
可以看到,当我们执行model(x)
的时候,底层干了以下几件事:
-
调用
forward
方法计算结果 -
判断有没有注册
forward_hook
,有的话,就将forward
的输入及结果作为hook
的实参。然后让hook
自己干一些不可告人的事情。
看到这,我们就明白hook
签名的意思了,还有为什么hook
不能修改input
的output
的原因。
小例子:
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
register_backward_hook
在module
上注册一个bachward hook
。此方法目前只能用在Module
上,不能用在Container
上,当Module
的forward函数中只有一个Function
的时候,称为Module
,如果Module
包含其它Module
,称之为Container
每次计算module
的inputs
的梯度的时候,这个hook
会被调用。hook
应该拥有下面的signature
。
hook(module, grad_input, grad_output) -> Tensor or None
如果module
有多个输入输出的话,那么grad_input
grad_output
将会是个tuple
。
hook
不应该修改它的arguments
,但是它可以选择性的返回关于输入的梯度,这个返回的梯度在后续的计算中会替代grad_input
。
这个函数返回一个 句柄(handle
)。它有一个方法 handle.remove()
,可以用这个方法将hook
从module
移除。
从上边描述来看,backward hook
似乎可以帮助我们处理一下计算完的梯度。看下面nn.Module
中register_backward_hook
方法的实现,和register_forward_hook
方法的实现几乎一样,都是用字典把注册的hook
保存起来。
- 1
- 2
- 3
- 4
- 1
- 2
- 3
- 4
先看个例子来看一下hook
的参数代表了什么:
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
可以看出,grad_in
保存的是,此模块Function
方法的输入的值的梯度。grad_out
保存的是,此模块forward
方法返回值的梯度。我们不能在grad_in
上直接修改,但是我们可以返回一个新的new_grad_in
作为Function
方法inputs
的梯度。
上述代码对variable
和module
同时注册了backward hook
,这里要注意的是,无论是module hook
还是variable hook
,最终还是注册到Function
上的。这点通过查看Varible
的register_hook
源码和Module
的__call__
源码得知。
Module的register_backward_hook的行为在未来的几个版本可能会改变
BP过程中Function
中的动作可能是这样的
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
关于pytorch run_backward()
的可能实现猜测为。
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
中间Variable的梯度在BP的过程中是保存到GradBuffer中的(C++源码中可以看到), BP完会释放. 如果retain_grads=True的话,就不会被释放