My torch version: 1.8.1+cu111
My paddle version: 2.4.1
torch API locations
torch.is_grad_enabled()
torch.set_grad_enabled(mode)
torch.no_grad()
paddle API locations
paddle.is_grad_enabled()
paddle.set_grad_enabled(mode)
paddle.no_grad()
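As a quick check that these paddle APIs work the same way, here is a minimal sketch (my own example, not from the original post), assuming paddle 2.4.1's default dynamic-graph mode:

import paddle

print(paddle.is_grad_enabled())      # True by default in dynamic-graph mode
with paddle.no_grad():
    print(paddle.is_grad_enabled())  # False inside the context
print(paddle.is_grad_enabled())      # True again after the block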
The two frameworks' APIs are aligned in the current versions. This article explains the difference between the following two usages:
with torch.no_grad():
    <code>

and

torch.set_grad_enabled(False)
<code>
torch.set_grad_enabled(True)
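To make the comparison concrete, here is a small sketch of both forms under torch 1.8.1 (my own example, not from the question); both produce a result with requires_grad=False:

import torch

x = torch.tensor([1.0], requires_grad=True)

# Form 1: context manager.
with torch.no_grad():
    y = x * 2
print(y.requires_grad)          # False

# Form 2: explicit toggle.
torch.set_grad_enabled(False)
z = x * 2
torch.set_grad_enabled(True)
print(z.requires_grad)          # False

print(torch.is_grad_enabled())  # True: gradient tracking is back on in both cases

Note that the second form re-enables gradients unconditionally, whereas no_grad restores whatever the previous mode was; the source code below makes this visible.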
Actually, there is no difference between the two in this case. If you take a look at the source code of no_grad, you can see that it is implemented internally using torch.set_grad_enabled:
class no_grad(object):
    r"""Context-manager that disabled gradient calculation.

    Disabling gradient calculation is useful for inference, when you are sure
    that you will not call :meth:`Tensor.backward()`. It will reduce memory
    consumption for computations that would otherwise have `requires_grad=True`.
    In this mode, the result of every computation will have
    `requires_grad=False`, even when the inputs have `requires_grad=True`.

    Also functions as a decorator.

    Example::

        >>> x = torch.tensor([1], requires_grad=True)
        >>> with torch.no_grad():
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> @torch.no_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> z = doubler(x)
        >>> z.requires_grad
        False
    """

    def __init__(self):
        self.prev = torch.is_grad_enabled()

    def __enter__(self):
        torch._C.set_grad_enabled(False)

    def __exit__(self, *args):
        torch.set_grad_enabled(self.prev)
        return False

    def __call__(self, func):
        @functools.wraps(func)
        def decorate_no_grad(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return decorate_no_grad
Normally, to write a context manager you define a class containing an __enter__() method and an __exit__() method. The code in __enter__ runs right before the body of the with block, and the code in __exit__ runs right after the body has finished.
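As an illustration of that protocol, here is a toy context manager (hypothetical, not from the torch source) that only reports when it is entered and exited:

class LogGuard(object):
    def __enter__(self):
        print("__enter__ runs before the with body")

    def __exit__(self, exc_type, exc_value, traceback):
        print("__exit__ runs after the with body")
        return False  # do not swallow exceptions

with LogGuard():
    print("inside the with body")

# Printed order: __enter__ ..., inside the with body, __exit__ ...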
In __enter__, gradient calculation is first switched off via set_grad_enabled(False):
def __enter__(self):
    torch._C.set_grad_enabled(False)
while __exit__ restores the previous gradient mode:
def __exit__(self, *args):
    torch.set_grad_enabled(self.prev)
    return False
That previous state was saved in __init__:
self.prev = torch.is_grad_enabled()
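Because __init__ records the previous mode and __exit__ restores it, no_grad behaves correctly even when the surrounding code had already disabled gradients, while a hard-coded torch.set_grad_enabled(True) at the end does not. A small sketch of that difference (my own example):

import torch

torch.set_grad_enabled(False)   # gradients already off in the outer scope

with torch.no_grad():
    pass
print(torch.is_grad_enabled())  # False: no_grad restored the outer setting

torch.set_grad_enabled(False)
# <code>
torch.set_grad_enabled(True)    # unconditionally re-enables gradients
print(torch.is_grad_enabled())  # True, even though the outer scope had turned them off

torch.set_grad_enabled(True)    # reset to the default before moving on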
For more details, see:
https://discuss.pytorch.org/t/difference-between-set-grad-enabled-false-and-no-grad/72038
https://stackoverflow.com/questions/53447345/pytorch-set-grad-enabledfalse-vs-with-no-grad