一、 torch.no_grad()
方法1:使用with torch.no_grad()控制requires_grad属性的作用域。
import torch
x = torch.ones(2,2,requires_grad=True)
y = x**2
print(x.requires_grad)
with torch.no_grad():
z = x**2
print(z.requires_grad)
b = x**2
print(b.requires_grad)
输出:
True
False
True
方法2:使用 @ 装饰器和 torch.no_grad()控制requires_grad作用域,限制函数级别的运算梯度,使张量的requires_grad失效
import torch
@torch.no_grad() # 装饰器,无梯度
def doubler(x):
return x*2
x = torch.ones(2,2,requires_grad=True)
y = doubler(x)
print(y.requires_grad)
def doubler2(x): # 不带装饰器,有梯度
return x*2
z = doubler2(x)
print(z.requires_grad)
输出:
False
True
二、torch.enable_grad()
方法1:使用with torch.enable_grad()控制requires_grad作用域
import torch
x = torch.ones(2,2,requires_grad=True)
with torch.no_grad():
with torch.enable_grad():
z = x**2
print(z.requires_grad)
方法2: 使用装饰器 @torch.enable_grad()装饰函数后,该函数不再受with torch.no_grad()的影响
import torch
@torch.enable_grad()
def doubler(x):
return x*2
x = torch.ones(2,2,requires_grad=True)
with torch.no_grad():
y = doubler(x)
print(y.requires_grad)
输出:
True
翻译自:李金洪老师的基于Bert的语言处理实战