[转载] set_grad_enabled(False) VS with no_grad()

我的 torch 版本: 1.8.1+cu111
我的 paddle 版本: 2.4.1

torch API位置

torch.is_grad_enabled()
torch.set_grad_enabled(mode)
torch.no_grad()

paddle API位置

paddle.is_grad_enabled()
paddle.set_grad_enabled(mode)
paddle.no_grad()

二者的API在当前版本是对齐的,本文是说以下二者的区别

with torch.no_grad():
    <code>

torch.set_grad_enabled(False)
<code>
torch.set_grad_enabled(True)

Actually no, there no difference in the way used in the question. When you take a look at the source code of no_grad. You see that it is actually using torch.set_grad_enabled to archive this behaviour:

事实上,在本问题中二者无区别。我们来看下 no_grad 的源码,其内部就是借助 torch.set_grad_enabled 去实现的:

class no_grad(object):
    r"""Context-manager that disabled gradient calculation.

    Disabling gradient calculation is useful for inference, when you are sure
    that you will not call :meth:`Tensor.backward()`. It will reduce memory
    consumption for computations that would otherwise have `requires_grad=True`.
    In this mode, the result of every computation will have
    `requires_grad=False`, even when the inputs have `requires_grad=True`.

    Also functions as a decorator.


    Example::

        >>> x = torch.tensor([1], requires_grad=True)
        >>> with torch.no_grad():
        ...   y = x * 2
        >>> y.requires_grad
        False
        >>> @torch.no_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> z = doubler(x)
        >>> z.requires_grad
        False
    """

    def __init__(self):
        self.prev = torch.is_grad_enabled()

    def __enter__(self):
        torch._C.set_grad_enabled(False)

    def __exit__(self, *args):
        torch.set_grad_enabled(self.prev)
        return False

    def __call__(self, func):
        @functools.wraps(func)
        def decorate_no_grad(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return decorate_no_grad

通常情况下,如果要写一个上下文管理器,需要定义一个类,里面包含一个 __enter__() 和一个 __exit__() 方法

也就是执行 with 内部的程序体之前会执行 __enter__ 内部代码,with 内部的程序执行之后,会执行__exit__ 内部代码

__enter__ 中先设置梯度 set_grad_enabled 为 False

    def __enter__(self):
        torch._C.set_grad_enabled(False)

__exit__ 中设置之前的梯度设置

    def __exit__(self, *args):
        torch.set_grad_enabled(self.prev)
        return False

__init__ 中保留了之前的状态:

self.prev = torch.is_grad_enabled()

其余可以参考:
https://discuss.pytorch.org/t/difference-between-set-grad-enabled-false-and-no-grad/72038
https://stackoverflow.com/questions/53447345/pytorch-set-grad-enabledfalse-vs-with-no-grad

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值