Reference link: torch.Tensor.register_hook()
Documentation:
register_hook(hook)
Registers a backward hook.
The hook will be called every time a gradient with respect to the Tensor is computed. The hook should have the following signature:
hook(grad) -> Tensor or None
The hook should not modify its argument, but it can optionally return a new gradient which will be used in place of grad.
This function returns a handle with a method handle.remove() that removes the hook from the module.
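The signature and handle semantics above can be sketched as follows (a minimal illustration, not from the PyTorch docs): a hook may return None to merely observe the gradient, or return a new tensor that is used in place of grad.

```python
import torch

seen = []

def observe(grad):
    # Returning None leaves the gradient unchanged; we only record it.
    seen.append(grad.clone())

v = torch.tensor([1., 2.], requires_grad=True)
handle = v.register_hook(observe)
v.backward(torch.tensor([1., 1.]))
handle.remove()  # the returned handle detaches the hook again
```

After backward(), `seen` holds the observed gradient and `v.grad` is unchanged by the hook.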
Example:
>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> h = v.register_hook(lambda grad: grad * 2)  # double the gradient
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> v.grad
2
4
6
[torch.FloatTensor of size (3,)]
>>> h.remove()  # removes the hook
Code experiment 1:
Microsoft Windows [Version 10.0.18363.1256]
(c) 2019 Microsoft Corporation. All rights reserved.
C:\Users\chenxuqi>conda activate ssd4pytorch1_2_0
(ssd4pytorch1_2_0) C:\Users\chenxuqi>python
Python 3.7.7 (default, May 6 2020, 11:45:54) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001EB8032D330>
>>>
>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> h = v.register_hook(lambda grad: grad * 2) # double the gradient
>>> print(v.grad)
None
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> print(v.grad)
tensor([2., 4., 6.])
>>> v.grad
tensor([2., 4., 6.])
>>> h.remove() # removes the hook
>>> v.grad.zero_()
tensor([0., 0., 0.])
>>> v.grad
tensor([0., 0., 0.])
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> v.grad
tensor([1., 2., 3.])
>>>
>>>
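A short sketch of why the transcript above calls v.grad.zero_() before the second backward() (an illustration, not from the original post): backward() accumulates into .grad, so successive calls sum up unless the gradient is zeroed.

```python
import torch

v = torch.tensor([0., 0., 0.], requires_grad=True)
v.backward(torch.tensor([1., 2., 3.]))
v.backward(torch.tensor([1., 2., 3.]))
accumulated = v.grad.clone()  # the two backward calls have been summed
v.grad.zero_()
v.backward(torch.tensor([1., 2., 3.]))
fresh = v.grad.clone()        # only the last call's gradient remains
```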
Code experiment 2:
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x00000259BC41D330>
>>>
>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> print(v.grad)
None
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> print(v.grad)
tensor([1., 2., 3.])
>>>
>>>
>>>
A supplement on backward():
>>> import torch
>>> v = torch.tensor([7., 8., 9.], requires_grad=True)
>>> print(v.grad)
None
>>> y = 2*v**2
>>> print(v.grad)
None
>>> y.backward(torch.tensor([1., 2., 3.]))
>>> print(v.grad)
tensor([ 28., 64., 108.])
>>> # alpha = [1., 2., 3.]
>>> # y' = 4 * v * alpha
>>> # verified: the printed gradient matches this formula
>>>
>>>
>>>
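The derivation above can be double-checked with torch.autograd.grad (a sketch, not from the original post): for y = 2*v**2, calling backward with weights alpha computes the vector-Jacobian product dL/dv = 4 * v * alpha.

```python
import torch

v = torch.tensor([7., 8., 9.], requires_grad=True)
alpha = torch.tensor([1., 2., 3.])
y = 2 * v ** 2
# torch.autograd.grad returns a tuple of gradients, one per input
(grad,) = torch.autograd.grad(y, v, grad_outputs=alpha)
expected = 4 * v.detach() * alpha  # the hand-derived formula
```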
Code experiment 3:
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001A24219D330>
>>>
>>> def grad_hook(grad):
... y_grad.append(grad)
... print("Running the custom hook...")
...
>>> y_grad = list()
>>> y_grad
[]
>>> x = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)
>>> y = x ** 2 + 1
>>> y.register_hook(grad_hook)
<torch.utils.hooks.RemovableHandle object at 0x000001A23305DF48>
>>> z = torch.mean(y*y)
>>> print(z.grad)
None
>>> print(y.grad)
None
>>> print(x.grad)
None
>>> print(y_grad)
[]
>>>
>>> z.backward()
Running the custom hook...
>>> print(y_grad)
[tensor([[1.0000, 2.5000],
[5.0000, 8.5000]])]
>>> print(z.grad)
None
>>> print(y.grad)
None
>>> print(x.grad)
tensor([[ 2., 10.],
[30., 68.]])
>>> # derivation gives the formula: dz/dx = x**3 + x
>>> # verified: the printed gradient matches this formula
>>>
>>>
>>>
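An alternative to the hook in experiment 3 (an illustration, not from the original post): y.retain_grad() asks autograd to keep the gradient of the intermediate tensor in y.grad directly, instead of collecting it inside a hook.

```python
import torch

x = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)
y = x ** 2 + 1
y.retain_grad()  # no hook needed; y.grad will be populated
z = torch.mean(y * y)
z.backward()
# dz/dy = y / 2 and dz/dx = x**3 + x, matching the transcript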
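An alternative to the hook in experiment 3 (an illustration, not from the original post): y.retain_grad() asks autograd to keep the gradient of the intermediate tensor in y.grad directly, instead of collecting it inside a hook.

```python
import torch

x = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)
y = x ** 2 + 1
y.retain_grad()  # no hook needed; y.grad will be populated
z = torch.mean(y * y)
z.backward()
# dz/dy = y / 2 and dz/dx = x**3 + x, matching the transcript
```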
Manual gradient derivation:
Code experiment 4:
>>> import torch
>>>
>>> def grad_hook(grad):
... return grad * 20200910.0
...
>>> x = torch.tensor([2., 2., 2., 2.], requires_grad=True)
>>> y = torch.pow(x, 2)
>>> z = torch.mean(y)
>>> h = x.register_hook(grad_hook)
>>> print(x.grad)
None
>>> z.backward(retain_graph=True)
>>> print(x.grad)
tensor([20200910., 20200910., 20200910., 20200910.])
>>>
>>>
>>> h.remove() # removes the hook
>>> x.grad.zero_()
tensor([0., 0., 0., 0.])
>>> x.grad
tensor([0., 0., 0., 0.])
>>> z.backward()
>>>
>>> print(x.grad)
tensor([1., 1., 1., 1.])
>>>
>>>
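A common practical use of register_hook, building on experiment 4 (an illustration, not from the original post): clamp each gradient element during backward, a per-tensor form of gradient clipping.

```python
import torch

x = torch.tensor([2., 2., 2., 2.], requires_grad=True)
handle = x.register_hook(lambda g: g.clamp(max=0.5))
z = torch.mean(torch.pow(x, 2))  # dz/dx = x / 2 = 1.0 per element
z.backward()
clipped = x.grad.clone()  # each 1.0 has been clamped to 0.5
handle.remove()
```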