一、pytorch中的model.eval() 和 model.train()
再pytorch中我们可以使用eval和train来控制模型是出于验证还是训练模式,那么两者对网络模型的具体影响是什么呢?
1. model.eval()
eval主要是用来影响网络中的dropout层和batchnorm层的行为。在dropout层保留所有的神经网络单元,batchnorm层使用在训练阶段学习得到的mean和var值。另外eval不会影响网络参数的梯度的计算,只不过不回传更新参数而已。所以eval模式要比with torch.no_grad更费时间和显存。
2. model.train()
这个就是训练模式,是网络的默认模式。在这个模式下,dropout层会按照设置好的失活概率进行失活,batchnorm会继续计算数据的均值和方差等参数并在每个batch size之间不断更新。
参考:model.train() and model.eval() do not change any behavior of the gradient calculations, but are used to set specific layers like dropout and batchnorm to evaluation mode (dropout won’t drop activations, batchnorm will use running estimates instead of batch statistics).
3. with troch.no_grad
with torch.no_grad会影响网络的自动求导机制,也就是网络前向传播后不会进行求导和进行反向传播。另外他不会影响dropout层和batchnorm层。
对dropout进行验证:
import torch
import torch.nn as nn
drop = nn.Dropout()
x = torch.ones(1, 10)
print(x)
drop.train()
print(drop(x))
drop.eval()
print(drop(x))
drop.train()
print(drop(x))
with torch.no_grad():
print(drop(x))
结果:
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
tensor([[0., 2., 2., 2., 2., 0., 2., 2., 2., 0.]])
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
tensor([[0., 0., 0., 2., 0., 2., 2., 2., 0., 0.]])
tensor([[0., 0., 2., 0., 2., 2., 0., 2., 0., 0.]])
此结果也就表明with torch.no_grad() 对dropout不会产生影响。
4. torch.set_grad_enabled(mode)
与with troch.no_grad 相似,会将在这个with包裹下的所有的计算出的 新的变量 的required_grad 置为false。但原有的变量required_grad 不会改变。这实际上也就是影响了网络的自动求导机制。与with torch.no_grad() 相似,不过接受一个bool类型的值。
参考:(https://discuss.pytorch.org/t/confused-about-set-grad-enabled/38417)
-
model.train() and model.eval() change the behavior of some layers. E.g. nn.Dropout won’t drop anymore and nn.BatchNorm layers will use the running estimates instead of the batch statistics. The torch.set_grad_enabled line of code makes sure to clear the intermediate values for evaluation, which are needed to backpropagate during training, thus saving memory. It’s comparable to the with torch.no_grad() statement but takes a bool value.
-
All new operations in the torch.set_grad_enabled(False) block won’t require gradients. However, the model parameters will still require gradients.