【Pytorch】控制上下文局部梯度更新

在模型的evaluation阶段和实际应用时,需要关闭pytorch自带的自动求导autograd机制,以防止验证/应用数据对网络参数的变动,同时减少自动求导带来的运算和存储消耗。

其常见的控制场景包括:
(1)禁止计算局部梯度
(2)在禁止计算局部梯度中,允许更精细的局部梯度计算
(3)根据判断条件,控制是否允许进行梯度更新

下面分别就上述三个场景,介绍常用的写法。

场景一:禁止计算局部梯度

pytorch提供了上下文管理器和装饰器两种方式进行控制。

# 方案一:上下文管理器
with torch.no_grad():
	pass

# 方案二:装饰器
@torch.no_grad()
def tensor_func():
	pass
场景二:在禁止计算局部梯度中,允许更精细的局部梯度计算

pytorch同样提供了上下文管理器和装饰器两种方式进行控制。

# 方案一:上下文管理器
with torch.no_grad():    # 禁止局部梯度
	with torch.enable_grad():   # 允许局部梯度
		pass

# 方案二:装饰器
@torch.no_grad()
def outer_tensor_func():
	@torch.enable_grad()
	def inner_tensor_func():
		pass
场景三:根据判断条件,控制是否允许进行梯度更新

pytorch提供了上下文管理器的方式进行控制。

"""
参数mode为一个逻辑判断句,若为True,则会允许局部梯度;否则禁止
"""
with torch.set_grad_enabled(mode):
	pass

其典型应用是将train阶段和eval阶段的计算过程统一写在同一个上下文管理器中,如:

with torch.set_grad_enabled(phase=='train'):
	pass
  • 4
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch中,参数的梯度通常由autograd模块自动计算和跟踪。然而,有时候我们可能希望某些参数的梯度保持不变,即使在模型的训练期间也不更新。 为了实现这一点,我们可以使用requires_grad属性将参数的梯度开关关闭。当requires_grad属性设置为False时,参数的梯度将不会被计算和更新。 例如,假设我们有一个参数张量weights,并且我们想要保持它的梯度不变。我们可以使用如下方式: ```python import torch weights = torch.randn(3, 3, requires_grad=True) # 创建一个参数张量,开启梯度计算 # 将requires_grad设置为False,关闭梯度计算 weights.requires_grad = False # 用参数张量执行一些操作(例如前向传播和损失计算) output = weights.sum() # 进行反向传播并打印参数的梯度。由于requires_grad为False,梯度将为None output.backward() print(weights.grad) # 输出为None ``` 在上述示例中,我们首先创建一个参数张量weights,并将其requires_grad属性设置为True,以便在执行后续操作时计算梯度。然后,我们将requires_grad属性设置为False,使得参数的梯度保持不变。最后,我们进行反向传播,但由于requires_grad为False,参数的梯度将为None。 需要注意的是,关闭参数的梯度计算仅适用于当前张量的操作。例如,如果在执行forward函数时使用了该参数,在后续步骤中的gradient calculation将不会受到requires_grad属性的影响。 总而言之,通过将参数的requires_grad属性设置为False,我们可以保持参数的梯度更新,从而控制参数的训练行为。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值