小白学Pytorch系列-- Torch API (4)
上下文管理器 torch.no_grad()
、torch.enable_grad()
和torch.set_grad_enabled()
有助于在本地禁用和启用梯度计算。有关其用法的更多详细信息,请参阅本地禁用梯度计算。这些上下文管理器是线程本地的,因此如果您使用线程模块等将工作发送到另一个线程,它们将无法工作。
>>> x = torch.zeros(1, requires_grad=True)
>>> with torch.no_grad():
... y = x * 2
>>> y.requires_grad
False
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
... y = x * 2
>>> y.requires_grad
False
>>> torch.set_grad_enabled(True) # this can also be used as a function
>>> y = x * 2
>>> y.requires_grad
True
>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False
NO_GRAD
禁用渐变计算的上下文管理器。
当您确定不会调用Tensor.backward()
时,禁用梯度计算对于推理非常有用。这将减少计算的内存消耗,否则这些计算将require_grad=True
。
在这种模式下,即使输入的requires_grad
为True,每次计算的结果也将为requires_grad=False
。
此上下文管理器是线程本地的;它不会影响其他线程中的计算。
也用作装饰器。(确保用括号实例化。)
x = torch.tensor([1.], requires_grad=True)
with torch.no_grad():
y = x * 2
y.requires_grad
@torch.no_grad()
def doubler(x):
return x * 2
z = doubler(x)
z.requires_grad
ENABLE_GRAD
启用梯度计算的上下文管理器。
启用梯度计算,如果它已通过 no_grad
或 set_grad_enabled
禁用。
这个上下文管理器是线程本地的;它不会影响其他线程中的计算。
也起到装饰器的作用。
x = torch.tensor([1.], requires_grad=True)
with torch.no_grad():
with torch.enable_grad():
y = x * 2
y.requires_grad
y.backward()
x.grad
@torch.enable_grad()
def doubler(x):
return x * 2
with torch.no_grad():
z = doubler(x)
z.requires_grad
SET_GRAD_ENABLED
将梯度计算设置为打开或关闭的上下文管理器。
set_grad_enabled
将根据其参数模式启用或禁用梯度。它可以用作上下文管理器或函数。
这个上下文管理器是线程本地的;它不会影响其他线程中的计算。
x = torch.tensor([1.], requires_grad=True)
is_train = False
with torch.set_grad_enabled(is_train):
y = x * 2
y.requires_grad
_ = torch.set_grad_enabled(True)
y = x * 2
y.requires_grad
_ = torch.set_grad_enabled(False)
y = x * 2
y.requires_grad
TORCH.IS_GRAD_ENABLED
如果当前启用渐变模式,则返回True。
如果当前启用了梯度模式,则返回 True。
INFERENCE_MODE
启用或禁用推理模式的上下文管理器
InferenceMode 是一个类似于 no_grad 的新上下文管理器,当您确定您的操作不会与 autograd 交互时使用(例如,模型训练)。在此模式下运行的代码通过禁用视图跟踪和版本计数器颠簸获得更好的性能。请注意,与本地启用或禁用 grad 的某些其他机制不同,进入 inference_mode 也会禁用转发模式 AD。
这个上下文管理器是线程本地的;它不会影响其他线程中的计算。
也起到装饰器的作用。 (确保用括号实例化。)
import torch
x = torch.ones(1, 2, 3, requires_grad=True)
with torch.inference_mode():
y = x * x
y.requires_grad
y._version
@torch.inference_mode()
def func(x):
return x * x
out = func(x)
out.requires_grad
TORCH.IS_INFERENCE_MODE_ENABLED
如果当前启用了推理模式,则返回 True。