一、禁止计算局部梯度
torch.autogard.no_grad: 禁用梯度计算的上下文管理器。
当确定不会调用Tensor.backward()计算梯度时,设置禁止计算梯度会减少内存消耗。如果需要计算梯度设置Tensor.requires_grad=True
两种禁用方法:
- 将不用计算梯度的变量放在with torch.no_grad()里
>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
... y = x * 2
>>> y.requires_grad
Out[12]:False
- 使用装饰器 @torch.no_gard()修饰的函数,在调用时不允许计算梯度
>>> @torch.no_grad()
... def doubler(x):
... return x * 2
>>> z = doubler(x)
>>> z.requires_grad
Out[13]:False
二、禁止后允许计算局部梯度
torch.autogard.enable_grad :允许计算梯度的上下文管理器
在一个no_grad上下文中使能梯度计算。在no_grad外部此上下文管理器无影响.
用法和上面类似:
- 使用with torch.enable_grad()允许计算梯度
>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
... with torch.enable_grad():
... y = x * 2
>>> y.requires_grad
Out[14]:True
>>> y.backward() # 计算梯度
>>> x.grad
Out[15]: tensor([2.])
- 在禁止计算梯度下调用被允许计算梯度的函数,结果可以计算梯度
>>> @torch.enable_grad()
... def doubler(x):
... return x * 2
>>> with torch.no_grad():
... z = doubler(x)
>>> z.requires_grad
Out[16]:True
三、是否计算梯度
torch.autograd.set_grad_enable()
可以作为一个函数使用:
>>> x = torch.tensor([1.], requires_grad=True)
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
... y = x * 2
>>> y.requires_grad
Out[17]:False
>>> torch.set_grad_enabled(True)
>>> y = x * 2
>>> y.requires_grad
Out[18]:True
>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
Out[19]:False
总结:
单独使用这三个函数时没有什么,但是若是嵌套,遵循就近原则。
x = torch.tensor([1.], requires_grad=True)
with torch.enable_grad():
torch.set_grad_enabled(False)
y = x * 2
print(y.requires_grad)
Out[20]: False
torch.set_grad_enabled(True)
with torch.no_grad():
z = x * 2
print(z.requires_grad)
Out[21]:False
参考: