Usage Examples of torch.Tensor.register_hook()

Reference: torch.Tensor.register_hook()


Documentation:

register_hook(hook)

Registers a backward hook.

The hook will be called every time a gradient with respect to the Tensor is computed. The hook should have the following signature:

hook(grad) -> Tensor or None

The hook should not modify its argument, but it can optionally return a new gradient which will be used in place of grad.

This function returns a handle with a method handle.remove() that removes the hook from the module.

Example:
>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> h = v.register_hook(lambda grad: grad * 2)  # double the gradient
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> v.grad

 2
 4
 6
[torch.FloatTensor of size (3,)]

>>> h.remove()  # removes the hook
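
The key point in the signature above is the return value: a hook that returns None merely observes the gradient, while a hook that returns a tensor replaces the gradient that flows onward. A minimal sketch of both kinds (the hook names here are illustrative, not from the official docs):

>>> import torch
>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> def observe_hook(grad):      # returns None: only inspects, grad is unchanged
...     print("observed:", grad)
...
>>> def scale_hook(grad):        # returns a tensor: it is used in place of grad
...     return grad * 2
...
>>> h1 = v.register_hook(observe_hook)
>>> h2 = v.register_hook(scale_hook)
>>> v.backward(torch.tensor([1., 2., 3.]))
observed: tensor([1., 2., 3.])
>>> v.grad
tensor([2., 4., 6.])
>>> h1.remove(); h2.remove()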

Code experiment 1:

Microsoft Windows [Version 10.0.18363.1256]
(c) 2019 Microsoft Corporation. All rights reserved.

C:\Users\chenxuqi>conda activate ssd4pytorch1_2_0

(ssd4pytorch1_2_0) C:\Users\chenxuqi>python
Python 3.7.7 (default, May  6 2020, 11:45:54) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001EB8032D330>
>>>
>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> h = v.register_hook(lambda grad: grad * 2)  # double the gradient
>>> print(v.grad)
None
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> print(v.grad)
tensor([2., 4., 6.])
>>> v.grad
tensor([2., 4., 6.])
>>> h.remove()  # removes the hook
>>> v.grad.zero_()
tensor([0., 0., 0.])
>>> v.grad
tensor([0., 0., 0.])
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> v.grad
tensor([1., 2., 3.])
>>>
>>>
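
Note why v.grad.zero_() is needed in the session above: .grad accumulates across backward() calls rather than being overwritten, so without zeroing, the old doubled gradient would be added to the new one. A minimal sketch of the accumulation (not part of the original session):

>>> import torch
>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> v.grad    # [1., 2., 3.] + [1., 2., 3.]: the two calls accumulate
tensor([2., 4., 6.])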

Code experiment 2:

>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x00000259BC41D330>
>>>
>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> print(v.grad)
None
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> print(v.grad)
tensor([1., 2., 3.])
>>>
>>>
>>>
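
Experiment 2 is the baseline without any hook: for a leaf tensor, v.backward(g) simply accumulates g into v.grad, because the derivative of a tensor with respect to itself is the identity:

    dv_i/dv_j = 1 if i == j else 0   =>   v.grad = g = [1., 2., 3.]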

A supplement on backward():

>>> import torch
>>> v = torch.tensor([7., 8., 9.], requires_grad=True)
>>> print(v.grad)
None
>>> y = 2*v**2
>>> print(v.grad)
None
>>> y.backward(torch.tensor([1., 2., 3.]))
>>> print(v.grad)
tensor([ 28.,  64., 108.])
>>> # alpha = [1., 2., 3.]
>>> # y' = 4 * v * alpha
>>> # the printed gradient matches this formula
>>>
>>>
>>>
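
The result can be checked element by element: with y_i = 2 * v_i**2 we have dy_i/dv_i = 4 * v_i, and y.backward(alpha) accumulates alpha_i * 4 * v_i into v.grad:

    v = [7., 8., 9.],  alpha = [1., 2., 3.]
    v.grad = [1*4*7, 2*4*8, 3*4*9] = [28., 64., 108.]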

Code experiment 3:

>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001A24219D330>
>>>
>>> def grad_hook(grad):
...     y_grad.append(grad)
...     print("running the custom hook function...")
...
>>> y_grad = list()
>>> y_grad
[]
>>> x = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)
>>> y = x ** 2 + 1
>>> y.register_hook(grad_hook)
<torch.utils.hooks.RemovableHandle object at 0x000001A23305DF48>
>>> z = torch.mean(y*y)
>>> print(z.grad)
None
>>> print(y.grad)
None
>>> print(x.grad)
None
>>> print(y_grad)
[]
>>>
>>> z.backward()
running the custom hook function...
>>> print(y_grad)
[tensor([[1.0000, 2.5000],
        [5.0000, 8.5000]])]
>>> print(z.grad)
None
>>> print(y.grad)
None
>>> print(x.grad)
tensor([[ 2., 10.],
        [30., 68.]])
>>> # mathematical derivation gives the formula: z' = x^3 + x
>>> # the printed gradients match this formula
>>>
>>>
>>>

Manual gradient derivation:
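With y = x**2 + 1 and z = mean(y*y) taken over the four elements, the chain rule gives:

    z = (1/4) * sum(y_ij**2)
    dz/dy_ij = y_ij / 2
    y = [[2., 5.], [10., 17.]]   =>   y_grad = [[1.0, 2.5], [5.0, 8.5]]
    dz/dx_ij = (y_ij / 2) * (2 * x_ij) = x_ij * y_ij = x_ij**3 + x_ij
    x = [[1., 2.], [3., 4.]]     =>   x.grad = [[2., 10.], [30., 68.]]

Both results match the values printed in code experiment 3.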

Code experiment 4:

>>> import torch
>>>
>>> def grad_hook(grad):
...     return grad * 20200910.0
...
>>> x = torch.tensor([2., 2., 2., 2.], requires_grad=True)
>>> y = torch.pow(x, 2)
>>> z = torch.mean(y)
>>> h = x.register_hook(grad_hook)
>>> print(x.grad)
None
>>> z.backward(retain_graph=True)
>>> print(x.grad)
tensor([20200910., 20200910., 20200910., 20200910.])
>>>
>>>
>>> h.remove()    # removes the hook
>>> x.grad.zero_()
tensor([0., 0., 0., 0.])
>>> x.grad
tensor([0., 0., 0., 0.])
>>> z.backward()
>>>
>>> print(x.grad)
tensor([1., 1., 1., 1.])
>>>
>>>
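
Checking experiment 4 by hand: z = mean(x**2) = (1/4) * sum(x_i**2), so dz/dx_i = x_i / 2 = 2/2 = 1 per element; with the hook attached, each element becomes 1 * 20200910.0 = 20200910.0. Note also that retain_graph=True in the first backward() call is what allows z.backward() to be called a second time after the hook is removed.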