PyTorch torch.no_grad()

本文详细介绍了PyTorch中torch.no_grad()的作用,它用于在神经网络推理阶段禁用梯度计算,以提高效率。torch.no_grad()是一个上下文管理器,也支持作为装饰器使用,其内部实现包括进入和退出时对梯度状态的管理。此外,文章还提到了torch.enable_grad()和torch.set_grad_enabled(mode)作为替代选项,它们同样用于控制梯度计算。通过示例代码展示了torch.no_grad()的使用效果,强调了其在避免不必要的梯度计算中的重要性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

torch.no_grad() 一般用于神经网络的推理阶段, 表示张量的计算过程中无需计算梯度


torch.no_grad 是一个类, 实现了 __enter__ 和 __exit__ 方法, 在进入环境管理器时记录梯度使能状态以及禁止梯度计算, 退出环境管理器时还原, 它还继承了 _DecoratorContextManager, 拥有装饰器的能力(依然是使用 with 语句)

# 摘自源码
class no_grad(_DecoratorContextManager):
    def __init__(self):
        self.prev = False

    def __enter__(self):
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(False)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch.set_grad_enabled(self.prev)

class _DecoratorContextManager:
    """Allow a context manager to be used as a decorator"""

    def __call__(self, func: F) -> F:
        @functools.wraps(func)
        def decorate_context(*args, **kwargs):
            with self.__class__():
                return func(*args, **kwargs)
        return cast(F, decorate_context)

    def __enter__(self) -> None:
        raise NotImplementedError

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        raise NotImplementedError

另外, torch.no_grad 用于代替旧版本的 volatile=True


import torch

x = torch.tensor([1.0], requires_grad=True)

y_1: torch.Tensor = x * x
y_1.backward()
print("y_1:", y_1.requires_grad, x.grad)

with torch.no_grad():
    y_2 = x * x
    print("y_2:", y_2.requires_grad)


@torch.no_grad()
def demo(x):
    y_3 = x * x
    print("y_3:", y_3.requires_grad)


demo(x)

打印

y_1: True tensor([2.])
y_2: False
y_3: False

y_1 是通常情况, y_1依赖于x, 而x需要求导, 所以y_1也需要求导, y_2 和 y_3 明确无需求导


除了 torch.no_grad() 还有 torch.enable_grad() 明确需要求导以及 torch.set_grad_enabled(mode), 它们均支持环境管理器和装饰器

# 单独使用 torch.set_grad_enabled
torch.set_grad_enabled(False)
y_4 = x * x
print("y_4:", y_4.requires_grad)

torch.set_grad_enabled(True)
y_5 = x * x
print("y_5:", y_5.requires_grad)

结果

y_4: False
y_5: True

底层实现位于 “aten/src/ATen/core/grad_mode.cpp”

thread_local bool GradMode_enabled = true;

bool GradMode::is_enabled() {
  return GradMode_enabled;
}

void GradMode::set_enabled(bool enabled) {
  GradMode_enabled = enabled;
}
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值