什么时候该用with torch.no_grad()?什么时候该用.requires_grad ==False?

一个torch基础问题,闲来无事想写写。
无论是否使用with torch.no_grad()还是.requires_grad == False,一般来说是不会影响算法本身的,但是会影响代码性能。

with torch.no_grad()

在这个下面进行运算得到的tensor没有grad_fn,也就是它不带梯度(因为没有上一级的函数),因此loss无法从这些tensor向上传递,产生这些tensor的网络的参数将不会更新。下面这种情况一般使用with torch.no_grad():
在这里插入图片描述
这里我们只是使用了net2的输出来计算loss,而不想让loss去更新net2的网络参数,于是使用with torch.no_grad(),这样loss就被阻断了在loss.backward过程中,而net1却正常计算网络参数梯度。如果没有使用with torch.no_grad(),也无妨,只是对net2的参数费时地计算了梯度,但是在optimizer.step的时候只有net1的参数step了。另外这也解释了为何更新前要optimizer.zero_grad,如果你像上面说的那样没有no_grad,net2的网络参数有梯度,在之后有backward了一次loss,将导致梯度叠加,也许这不是我们想要的结果(当然也有这样叠加梯度的,往往需要retain_graph==True)。

.requires_grad == False

下面这个例子可以使用:
在这里插入图片描述
这里我们只想通过loss更新net1,net2不想更新,还能通过with torch.no_grad()实现吗?答案是否定的,一旦使用就阻断了loss流动,那怎么办?如下:

for p in net2.parameters():
    p.requires_grad = False

这样不在去计算net2的网络权重w的梯度,而只是使用它的值去计算net1的梯度,提高了代码性能。
值得注意的是:.requires_grad的设置只能针对叶子结点(网络的权重w就算,bias也算),如何理解叶子结点呢?叶子节点在数据结构中是一棵树没有子节点的结点,在网络中就是这个结点对应的参数不是由更上一层的tensor计算而来。

关于retain_graph == True

这篇blog讲得很好。

  • 58
    点赞
  • 123
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

iπ弟弟

如果可以的话,请杯咖啡吧!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值