几行代码让你搞懂torch.no_grad
首先明确一点,no_grad与detach有异曲同工之妙,都是逃避autograd的追踪。
接下来我们做个实验:
a = torch.ones(2,requires_grad=True)
b = a*2
print(a, a.grad, a.requires_grad )
b.sum().backward(retain_graph = True )
print(a, a.grad, a.requires_grad )
with torch.no_grad():
a = a + a.grad
print(a, a.grad, a.requires_grad )
# a.grad.zero_()
b.sum().backward(retain_graph = True )
print(a, a.grad ,a.requires_grad )
-------------------------
tensor([1., 1.], requires_grad=True) None True
tensor([1., 1.], requires_grad=True) tensor([2., 2.]) True
tensor([3., 3.]) None False
tensor([3., 3.]) None False
我们在with torch.nograd()下使用了 =+的操作,这实际上生成了一个新的变量a,因为torch.no_grad的作用下使得a变量没法求梯度。
如果使用+=的操作:
a = torch.ones(2,requires_grad=True)
b = a*2
print(a, a.grad, a.requires_grad )
b.sum().backward(retain_graph = True )
print(a, a.grad, a.requires_grad )
with torch.no_grad():
a += a.grad
print(a, a.grad, a.requires_grad )
# a.grad.zero_()
b.sum().backward(retain_graph = True )
print(a, a.grad ,a.requires_grad )
---------------------------------
tensor([1., 1.], requires_grad=True) None True
tensor([1., 1.], requires_grad=True) tensor([2., 2.]) True
tensor([3., 3.], requires_grad=True) tensor([2., 2.]) True
tensor([3., 3.], requires_grad=True) tensor([4., 4.]) True
可以发现,+=的原地修改本来是不行的,因为autograd会检测你这个值是否变化,但是如果加上torch.no_grad()后就逃避了autograd的检测,在上下文管理器中只修改了tensor的data,属性没有修改,这样的话就可以对a进行求梯度的了,但是我们发现这个梯度被累加了,本来想要第二次反向传播的时候,最后a的输出不包含上一次的梯度。假定我在做一个梯度的更新操作,这个梯度累计越来越大,更新的步长越来越大,loss直接跑飞。所以得加一个梯度清零的操作。
a = torch.ones(2,requires_grad=True)
b = a*2
print(a, a.grad, a.requires_grad )
b.sum().backward(retain_graph = True )
print(a, a.grad, a.requires_grad )
with torch.no_grad():
a += a.grad
print(a, a.grad, a.requires_grad )
a.grad.zero_()
b.sum().backward(retain_graph = True )
print(a, a.grad ,a.requires_grad )
-----------------------------------
tensor([1., 1.], requires_grad=True) None True
tensor([1., 1.], requires_grad=True) tensor([2., 2.]) True
tensor([3., 3.], requires_grad=True) tensor([2., 2.]) True
tensor([3., 3.], requires_grad=True) tensor([2., 2.]) True
这种过程,就相当于梯度的更新了,在完成原地修改的时候能不被autograd检测到,就是torch.no_grad的一种使用场景。
接下来,就是no_grad的其它作用了,这种在本文的首页链接中可以仔细体会。