每个变量都有两个标志:requires_grad
和volatile
。它们都允许从梯度计算中精细地排除子图,并可以提高效率。
requires_grad
如果有一个单一的输入操作需要梯度,它的输出也需要梯度。相反,只有所有输入都不需要梯度,输出才不需要。如果其中所有的变量都不需要梯度进行,后向计算不会在子图中执行。
>>> x = Variable(torch.randn(5, 5))
>>> y = Variable(torch.randn(5, 5))
>>> z = Variable(torch.randn(5, 5), requires_grad=True)
>>> a = x + y
>>> a.requires_grad
False
>>> b = a + z
>>> b.requires_grad
True
这个标志特别有用,当您想要冻结部分模型时,或者您事先知道不会使用某些参数的梯度。例如,如果要对预先训练的CNN进行优化,只要切换冻结模型中的requires_grad
标志就足够了,直到计算到最后一层才会保存中间缓冲区,其中的仿射变换将使用需要梯度的权重并且网络的输出也将需要它们。
model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
param.requires_grad = False
# Replace the last fully-connected layer
# Parameters of newly constructed modules have requires_grad=True by default
model.fc = nn.Linear(512, 100)
# Optimize only the classifier
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)
volatile
纯粹的inference模式下推荐使用volatile
,当你确定你甚至不会调用.backward()
时。它比任何其他自动求导的设置更有效——它将使用绝对最小的内存来评估模型。volatile
也决定了require_grad is False
。
volatile
不同于require_grad
的传递。如果一个操作甚至只有有一个volatile
的输入,它的输出也将是volatile
。Volatility
比“不需要梯度”更容易传递——只需要一个volatile
的输入即可得到一个volatile
的输出,相对的,需要所有的输入“不需要梯度”才能得到不需要梯度的输出。使用volatile标志,您不需要更改模型参数的任何设置来用于inference。创建一个volatile
的输入就够了,这将保证不会保存中间状态。
>>> regular_input = Variable(torch.randn(5, 5))
>>> volatile_input = Variable(torch.randn(5, 5), volatile=True)
>>> model = torchvision.models.resnet18(pretrained=True)
>>> model(regular_input).requires_grad
True
>>> model(volatile_input).requires_grad
False
>>> model(volatile_input).volatile
True
>>> model(volatile_input).creator is None
True
pytorch中的 requires_grad和volatile - 牧马人夏峥 - 博客园
简单总结其用途
(1)requires_grad=Fasle时不需要更新梯度, 适用于冻结某些层的梯度;
(2)volatile=True相当于requires_grad=False,适用于推断阶段,不需要反向传播。这个现在已经取消了,使用with torch.no_grad()来替代
pytorch学习笔记(三):自动求导_u012436149的博客-CSDN博客
pytorch
的BP
过程是由一个函数决定的,loss.backward()
, 可以看到backward()
函数里并没有传要求谁的梯度。那么我们可以大胆猜测,在BP
的过程中,pytorch
是将所有影响loss
的Tensor
都求了一次梯度。**但是有时候,我们并不想求所有Tensor
的梯度。**那就要考虑如何在Backward过程中排除子图
(ie.排除没必要的梯度计算)。
如何BP
过程中排除子图? 这就用到了Tensor
中的一个参数requires_grad
为什么要排除子图
也许有人会问,梯度全部计算,不更新的话不就得了。
这样就涉及了效率的问题了,计算很多没用的梯度是浪费了很多资源的(时间,计算机内存)