在模型的evaluation阶段和实际应用时,需要关闭pytorch自带的自动求导autograd机制,以防止验证/应用数据对网络参数的变动,同时减少自动求导带来的运算和存储消耗。
其常见的控制场景包括:
(1)禁止计算局部梯度
(2)在禁止计算局部梯度中,允许更精细的局部梯度计算
(3)根据判断条件,控制是否允许进行梯度更新
下面分别就上述三个场景,介绍常用的写法。
场景一:禁止计算局部梯度
pytorch提供了上下文管理器和装饰器两种方式进行控制。
# 方案一:上下文管理器
with torch.no_grad():
pass
# 方案二:装饰器
@torch.no_grad()
def tensor_func():
pass
场景二:在禁止计算局部梯度中,允许更精细的局部梯度计算
pytorch同样提供了上下文管理器和装饰器两种方式进行控制。
# 方案一:上下文管理器
with torch.no_grad(): # 禁止局部梯度
with torch.enable_grad(): # 允许局部梯度
pass
# 方案二:装饰器
@torch.no_grad()
def outer_tensor_func():
@torch.enable_grad()
def inner_tensor_func():
pass
场景三:根据判断条件,控制是否允许进行梯度更新
pytorch提供了上下文管理器的方式进行控制。
"""
参数mode为一个逻辑判断句,若为True,则会允许局部梯度;否则禁止
"""
with torch.set_grad_enabled(mode):
pass
其典型应用是将train
阶段和eval
阶段的计算过程统一写在同一个上下文管理器中,如:
with torch.set_grad_enabled(phase=='train'):
pass