pytorch禁止/允许计算局部梯度

一、禁止计算局部梯度

torch.autogard.no_grad: 禁用梯度计算的上下文管理器。

当确定不会调用Tensor.backward()计算梯度时,设置禁止计算梯度会减少内存消耗。如果需要计算梯度设置Tensor.requires_grad=True

两种禁用方法:

  • 将不用计算梯度的变量放在with torch.no_grad()里
>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
...   y = x * 2
>>> y.requires_grad
Out[12]:False

  • 使用装饰器 @torch.no_gard()修饰的函数,在调用时不允许计算梯度
>>> @torch.no_grad()
... def doubler(x):
...     return x * 2
>>> z = doubler(x)
>>> z.requires_grad
Out[13]:False

二、禁止后允许计算局部梯度

torch.autogard.enable_grad :允许计算梯度的上下文管理器

在一个no_grad上下文中使能梯度计算。在no_grad外部此上下文管理器无影响.

用法和上面类似:

  • 使用with torch.enable_grad()允许计算梯度
>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
...   with torch.enable_grad():
...     y = x * 2
>>> y.requires_grad
Out[14]:True

>>> y.backward()  # 计算梯度
>>> x.grad
Out[15]: tensor([2.])
  • 在禁止计算梯度下调用被允许计算梯度的函数,结果可以计算梯度
>>> @torch.enable_grad()
... def doubler(x):
...     return x * 2

>>> with torch.no_grad():
...     z = doubler(x)
>>> z.requires_grad

Out[16]:True

三、是否计算梯度

torch.autograd.set_grad_enable()

可以作为一个函数使用:

>>> x = torch.tensor([1.], requires_grad=True)
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
...   y = x * 2
>>> y.requires_grad
Out[17]:False

>>> torch.set_grad_enabled(True)
>>> y = x * 2
>>> y.requires_grad
Out[18]:True

>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
Out[19]:False

总结:

单独使用这三个函数时没有什么,但是若是嵌套,遵循就近原则。

x = torch.tensor([1.], requires_grad=True)

with torch.enable_grad():
    torch.set_grad_enabled(False)
    y = x * 2
    print(y.requires_grad)
Out[20]: False

torch.set_grad_enabled(True)
with torch.no_grad():
    z = x * 2
    print(z.requires_grad)
Out[21]:False

参考:

https://pytorch.org/docs/stable/autograd.html

  • 16
    点赞
  • 52
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值