【PyTorch】with torch.no_grad() 详解

在阅读以下内容前,请务必先大致了解计算图机制,特别是叶子节点:pytorch——计算图与动态图机制

with torch.no_grad()PyTorch官网 中的定义为:

Context-manager that disables gradient calculation.

意思是with torch.no_grad()是一个用于禁用梯度的上下文管理器。禁用梯度计算对于推理是很有用的,当我们确定不会调用 Tensor.backward()时,它将减少计算的内存消耗。因为在此模式下,即使输入为 requires_grad=True,每次计算的结果也将具有requires_grad=False

总的来说, with torch.no_grad() 可以理解为,在管理器外产生的与原参数有关联的参数requires_grad属性都默认为True,而在该管理器内新产生的参数的requires_grad属性都将置为False


除此之外,with torch.no_grad()还通常与原地操作(in-place operation)组合在一起。原地操作有明确定义:

对于 requires_grad=True 的叶子张量(leaf tensor)不能使用 inplace operation

因为原地操作会覆盖当前内存的值,但叶子节点所指向的内存块进行无法进行修改操作,否则会导致其中梯度信息与节点的值不再有计算上的对应关系。

 with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_() # 清空当前梯度

于是我们针对以上操作进行探究,以更好理解该情况下with torch.no_grad()的作用。

  1. 不使用 with torch.no_grad() 进行原地操作
for param in params:
	param -= lr * param.grad / batch_size
	param.grad.zero_() # 清空当前梯度

运行上面的代码会报错,错误信息为RuntimeError: a leaf Variable that requires grad is being used in an in-place operation. 意思是在原地操作中使用了需要梯度的叶子节点。

如果你有意验证有无 with torch.no_grad() 进行原地操作的两种情况下 param 的 requires_grad 属性,你会发现其值都为True。那么可能有人会有疑问,影响原地操作的定义不就是 requires_grad 属性吗。那么你需要做的相信定义,并理解以下两层:

  1. lr * param.grad / batch_size 会创建一块临时内存,这块临时内存的 requires_grad 属性是 False
  2. param.grad 也会占用一块内存,其也具有 requires_grad 属性,且为 False
  1. 不使用 with torch.no_grad() 进行赋值操作
for param in params:
	param = param - lr * param.grad / batch_size
	print(param.is_leaf) # False
	param.grad.zero_() # 清空当前梯度

运行上面的代码会报错,错误信息为AttributeError: 'NoneType' object has no attribute 'zero_'。我们都知道赋值操作会新创建一块内存以存放数据,所以根据计算图理论,此时的param是中间节点,不再是叶子节点,不具有grad属性了。

  1. 使用 with torch.no_grad() 进行赋值操作
with torch.no_grad():
	for param in params:
		print(param.requires_grad) # True
		param = param - lr * param.grad / batch_size
		print(param.is_leaf) # True
		print(param.requires_grad) # Flase
		# param.requires_grad = True
		param.grad.zero_() # 清空当前梯度

运行上面的代码会报错,错误信息为AttributeError: 'NoneType' object has no attribute 'zero_'

我们知道在 PyTorch 中,前向传播过程中构建计算图,而反向传播时销毁计算图以释放内存并计算叶子节点的梯度信息。当我们使用 with torch.no_grad() 上下文管理器时,我们指示 PyTorch 在此上下文中不跟踪梯度信息,因此不会构建用于反向传播的计算图。尽管如此,由于在 torch.no_grad() 上下文中创建的张量(如 param)不依赖于计算图中的其他节点,它们仍然被视为叶子节点。因此,这些张量的梯度信息仍然可以被访问,但是梯度计算不会在该上下文中进行,因此在此上下文内产生的张量不会保存任何梯度信息。

with torch.no_grad()上下文中,param仍然是叶子节点。但是赋值操作会创建一个新的张量,并且这个新的张量中的requires_grad = False。理论上,我们可以将requires_grad重新设置为True,然后再进行反向传播,但这样做非常麻烦且没有意义,且会导致大量的内存占用。因此通常不建议这样做。如果想要尝试,可以修改函数结构,将操作集合到一个函数内共享参数,以确保梯度追踪和梯度计算的一致性。

  • 9
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值