class torch.autograd.enable_grad的使用举例

参考链接: class torch.autograd.enable_grad

在这里插入图片描述在这里插入图片描述

原文及翻译:

class torch.autograd.enable_grad
类型 torch.autograd.enable_grad

Context-manager that enables gradient calculation.
可以启用梯度计算的上下文管理器.

Enables gradient calculation, if it has been disabled via no_grad or set_grad_enabled.
如果当前已经通过使用no_grad或者通过使用set_grad_enabled来禁用梯度计算,
那么可以使用该上下文管理器来启用梯度计算.

This context manager is thread local; it will not affect computation in other threads.
该上下文管理器具有线程局部性,因此不会影响其他线程上的计算.

Also functions as a decorator. (Make sure to instantiate with parenthesis.)
该上下文管理器也可以以装饰器的方式来使用.(确保使用圆括号来进行初始化.)

实验举例1:

Microsoft Windows [版本 10.0.18363.1440]
(c) 2019 Microsoft Corporation。保留所有权利。

C:\Users\chenxuqi>conda activate pytorch_1.7.1_cu102

(pytorch_1.7.1_cu102) C:\Users\chenxuqi>python
Python 3.7.9 (default, Aug 31 2020, 17:10:11) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x00000247CA778870>
>>> a = torch.randn(3,4,requires_grad=True)
>>> a
tensor([[ 0.2824, -0.3715,  0.9088, -1.7601],
        [-0.1806,  2.0937,  1.0406, -1.7651],
        [ 1.1216,  0.8440,  0.1783,  0.6859]], requires_grad=True)
>>> a.grad, a.requires_grad
(None, True)
>>> with torch.no_grad():
...     with torch.enable_grad():
...             b = a * 2
...
>>> b
tensor([[ 0.5648, -0.7430,  1.8176, -3.5202],
        [-0.3612,  4.1874,  2.0812, -3.5303],
        [ 2.2433,  1.6879,  0.3567,  1.3718]], grad_fn=<MulBackward0>)
>>> b.requires_grad
True
>>>
>>> b = a * 2
>>> b.requires_grad
True
>>>
>>> with torch.set_grad_enabled(False):
...     with torch.enable_grad():
...             c = a * 3
...
>>> c
tensor([[ 0.8472, -1.1145,  2.7263, -5.2804],
        [-0.5418,  6.2810,  3.1219, -5.2954],
        [ 3.3649,  2.5319,  0.5350,  2.0576]], grad_fn=<MulBackward0>)
>>> c.requires_grad
True
>>>
>>> # c.sum().backward()
>>> a.grad, a.requires_grad
(None, True)
>>>
>>> c.sum().backward()
>>> a.requires_grad
True
>>> a.grad
tensor([[3., 3., 3., 3.],
        [3., 3., 3., 3.],
        [3., 3., 3., 3.]])
>>>
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x00000247CA778870>
>>> a = torch.randn(3,4,requires_grad=True)
>>> a
tensor([[ 0.2824, -0.3715,  0.9088, -1.7601],
        [-0.1806,  2.0937,  1.0406, -1.7651],
        [ 1.1216,  0.8440,  0.1783,  0.6859]], requires_grad=True)
>>> e = a * 4
>>> e.requires_grad
True
>>>
>>> torch.set_grad_enabled(False)
<torch.autograd.grad_mode.set_grad_enabled object at 0x00000247C9E44C08>
>>>
>>> f = a * 5
>>> f.requires_grad
False
>>>
>>> with torch.enable_grad():
...     g = a * 6
...
>>> g.requires_grad
True
>>>
>>>
>>>

实验举例2:

Microsoft Windows [版本 10.0.18363.1440]
(c) 2019 Microsoft Corporation。保留所有权利。

C:\Users\chenxuqi>conda activate pytorch_1.7.1_cu102

(pytorch_1.7.1_cu102) C:\Users\chenxuqi>python
Python 3.7.9 (default, Aug 31 2020, 17:10:11) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001E3EF408870>
>>> a = torch.randn(3,4,requires_grad=True)
>>> a
tensor([[ 0.2824, -0.3715,  0.9088, -1.7601],
        [-0.1806,  2.0937,  1.0406, -1.7651],
        [ 1.1216,  0.8440,  0.1783,  0.6859]], requires_grad=True)
>>> a.grad, a.requires_grad
(None, True)
>>>
>>> @torch.enable_grad()
... def multiplyN4cxq(x,N):
...     return x * N
...
>>> with torch.no_grad():
...     b = a * 10
...     c = multiplyN4cxq(a,100)
...
>>> b
tensor([[  2.8239,  -3.7148,   9.0878, -17.6012],
        [ -1.8060,  20.9368,  10.4062, -17.6514],
        [ 11.2164,   8.4397,   1.7833,   6.8588]])
>>> c
tensor([[  28.2389,  -37.1484,   90.8775, -176.0119],
        [ -18.0601,  209.3681,  104.0623, -176.5138],
        [ 112.1640,   84.3969,   17.8333,   68.5875]], grad_fn=<MulBackward0>)
>>>
>>> a.grad, a.requires_grad
(None, True)
>>> b.requires_grad
False
>>> c.requires_grad
True
>>>
>>>
>>>
>>>
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值