几行代码让你搞懂torch.no_grad

几行代码让你搞懂torch.no_grad

一篇讲得比较好的文章

首先明确一点,no_grad与detach有异曲同工之妙,都是逃避autograd的追踪。

接下来我们做个实验:

a = torch.ones(2,requires_grad=True)
b = a*2
print(a, a.grad, a.requires_grad )
b.sum().backward(retain_graph = True )
print(a, a.grad, a.requires_grad )
with torch.no_grad():
    a = a + a.grad
    print(a, a.grad, a.requires_grad )
    # a.grad.zero_()
b.sum().backward(retain_graph = True )
print(a, a.grad ,a.requires_grad )
-------------------------
tensor([1., 1.], requires_grad=True) None True
tensor([1., 1.], requires_grad=True) tensor([2., 2.]) True
tensor([3., 3.]) None False
tensor([3., 3.]) None False

我们在with torch.nograd()下使用了 =+的操作,这实际上生成了一个新的变量a,因为torch.no_grad的作用下使得a变量没法求梯度。

如果使用+=的操作:

a = torch.ones(2,requires_grad=True)
b = a*2
print(a, a.grad, a.requires_grad )
b.sum().backward(retain_graph = True )
print(a, a.grad, a.requires_grad )
with torch.no_grad():
    a += a.grad
    print(a, a.grad, a.requires_grad )
    # a.grad.zero_()
b.sum().backward(retain_graph = True )
print(a, a.grad ,a.requires_grad )
---------------------------------
tensor([1., 1.], requires_grad=True) None True
tensor([1., 1.], requires_grad=True) tensor([2., 2.]) True
tensor([3., 3.], requires_grad=True) tensor([2., 2.]) True
tensor([3., 3.], requires_grad=True) tensor([4., 4.]) True

可以发现,+=的原地修改本来是不行的,因为autograd会检测你这个值是否变化,但是如果加上torch.no_grad()后就逃避了autograd的检测,在上下文管理器中只修改了tensor的data,属性没有修改,这样的话就可以对a进行求梯度的了,但是我们发现这个梯度被累加了,本来想要第二次反向传播的时候,最后a的输出不包含上一次的梯度。假定我在做一个梯度的更新操作,这个梯度累计越来越大,更新的步长越来越大,loss直接跑飞。所以得加一个梯度清零的操作。

a = torch.ones(2,requires_grad=True)
b = a*2
print(a, a.grad, a.requires_grad )
b.sum().backward(retain_graph = True )
print(a, a.grad, a.requires_grad )
with torch.no_grad():
    a += a.grad
    print(a, a.grad, a.requires_grad )
    a.grad.zero_()
b.sum().backward(retain_graph = True )
print(a, a.grad ,a.requires_grad )
-----------------------------------
tensor([1., 1.], requires_grad=True) None True
tensor([1., 1.], requires_grad=True) tensor([2., 2.]) True
tensor([3., 3.], requires_grad=True) tensor([2., 2.]) True
tensor([3., 3.], requires_grad=True) tensor([2., 2.]) True

这种过程,就相当于梯度的更新了,在完成原地修改的时候能不被autograd检测到,就是torch.no_grad的一种使用场景。

接下来,就是no_grad的其它作用了,这种在本文的首页链接中可以仔细体会。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值