Reference link: class torch.autograd.enable_grad
Documentation:
class torch.autograd.enable_grad
Context-manager that enables gradient calculation.
Enables gradient calculation, if it has been disabled via no_grad or set_grad_enabled.
This context manager is thread local; it will not affect computation in other threads.
Also functions as a decorator. (Make sure to instantiate with parentheses.)
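The thread-local property above can be checked directly. Below is a minimal sketch (names like `results` and `worker` are illustrative, not from the original): a `no_grad` block in the main thread does not change the grad mode seen by a freshly spawned worker thread.

```python
import threading
import torch

results = {}

def worker():
    # A new thread starts with the default grad mode (enabled),
    # regardless of what the spawning thread has set.
    results["worker"] = torch.is_grad_enabled()

with torch.no_grad():
    results["main"] = torch.is_grad_enabled()  # False inside no_grad
    t = threading.Thread(target=worker)
    t.start()
    t.join()

print(results)
```

Because grad mode is thread local, the worker reports `True` even though it was started inside the main thread's `no_grad` block.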
Example 1:
Microsoft Windows [Version 10.0.18363.1440]
(c) 2019 Microsoft Corporation. All rights reserved.
C:\Users\chenxuqi>conda activate pytorch_1.7.1_cu102
(pytorch_1.7.1_cu102) C:\Users\chenxuqi>python
Python 3.7.9 (default, Aug 31 2020, 17:10:11) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x00000247CA778870>
>>> a = torch.randn(3,4,requires_grad=True)
>>> a
tensor([[ 0.2824, -0.3715,  0.9088, -1.7601],
        [-0.1806,  2.0937,  1.0406, -1.7651],
        [ 1.1216,  0.8440,  0.1783,  0.6859]], requires_grad=True)
>>> a.grad, a.requires_grad
(None, True)
>>> with torch.no_grad():
... with torch.enable_grad():
... b = a * 2
...
>>> b
tensor([[ 0.5648, -0.7430,  1.8176, -3.5202],
        [-0.3612,  4.1874,  2.0812, -3.5303],
        [ 2.2433,  1.6879,  0.3567,  1.3718]], grad_fn=<MulBackward0>)
>>> b.requires_grad
True
>>>
>>> b = a * 2
>>> b.requires_grad
True
>>>
>>> with torch.set_grad_enabled(False):
... with torch.enable_grad():
... c = a * 3
...
>>> c
tensor([[ 0.8472, -1.1145,  2.7263, -5.2804],
        [-0.5418,  6.2810,  3.1219, -5.2954],
        [ 3.3649,  2.5319,  0.5350,  2.0576]], grad_fn=<MulBackward0>)
>>> c.requires_grad
True
>>>
>>> # c.sum().backward()
>>> a.grad, a.requires_grad
(None, True)
>>>
>>> c.sum().backward()
>>> a.requires_grad
True
>>> a.grad
tensor([[3., 3., 3., 3.],
        [3., 3., 3., 3.],
        [3., 3., 3., 3.]])
>>>
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x00000247CA778870>
>>> a = torch.randn(3,4,requires_grad=True)
>>> a
tensor([[ 0.2824, -0.3715,  0.9088, -1.7601],
        [-0.1806,  2.0937,  1.0406, -1.7651],
        [ 1.1216,  0.8440,  0.1783,  0.6859]], requires_grad=True)
>>> e = a * 4
>>> e.requires_grad
True
>>>
>>> torch.set_grad_enabled(False)
<torch.autograd.grad_mode.set_grad_enabled object at 0x00000247C9E44C08>
>>>
>>> f = a * 5
>>> f.requires_grad
False
>>>
>>> with torch.enable_grad():
... g = a * 6
...
>>> g.requires_grad
True
>>>
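One detail worth noting from the session above: calling `torch.set_grad_enabled(False)` as a plain function (rather than as a context manager) flips the thread-local default for all subsequent code, which is why `f = a * 5` stopped tracking gradients. A minimal sketch of the two usage styles and of restoring the default:

```python
import torch

# Called as a plain function, set_grad_enabled changes the (thread-local)
# default until it is explicitly switched back.
torch.set_grad_enabled(False)
assert not torch.is_grad_enabled()

torch.set_grad_enabled(True)   # restore the default
assert torch.is_grad_enabled()

# Used as a context manager instead, the previous mode is
# restored automatically on exit.
with torch.set_grad_enabled(False):
    assert not torch.is_grad_enabled()
assert torch.is_grad_enabled()
```

The context-manager form is usually safer, since forgetting to restore the mode after a bare function call silently disables gradients for the rest of the session.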
Example 2:
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001E3EF408870>
>>> a = torch.randn(3,4,requires_grad=True)
>>> a
tensor([[ 0.2824, -0.3715,  0.9088, -1.7601],
        [-0.1806,  2.0937,  1.0406, -1.7651],
        [ 1.1216,  0.8440,  0.1783,  0.6859]], requires_grad=True)
>>> a.grad, a.requires_grad
(None, True)
>>>
>>> @torch.enable_grad()
... def multiplyN4cxq(x,N):
... return x * N
...
>>> with torch.no_grad():
... b = a * 10
... c = multiplyN4cxq(a,100)
...
>>> b
tensor([[  2.8239,  -3.7148,   9.0878, -17.6012],
        [ -1.8060,  20.9368,  10.4062, -17.6514],
        [ 11.2164,   8.4397,   1.7833,   6.8588]])
>>> c
tensor([[  28.2389,  -37.1484,   90.8775, -176.0119],
        [ -18.0601,  209.3681,  104.0623, -176.5138],
        [ 112.1640,   84.3969,   17.8333,   68.5875]], grad_fn=<MulBackward0>)
>>>
>>> a.grad, a.requires_grad
(None, True)
>>> b.requires_grad
False
>>> c.requires_grad
True
>>>
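A final caveat worth spelling out (not shown in the transcripts above, but easy to verify): `enable_grad` only restores gradient *mode*; it cannot make a tensor created with `requires_grad=False` start tracking history. A minimal sketch:

```python
import torch

# 'a' does not require grad, so even with grad mode explicitly enabled
# no autograd graph is built for operations on it.
a = torch.ones(2, requires_grad=False)

with torch.enable_grad():
    b = a * 2

# No input required grad, hence the result has no grad_fn either.
assert b.requires_grad is False
```

In other words, gradient tracking needs both grad mode to be on and at least one input with `requires_grad=True`; `enable_grad` only controls the former.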