1. First, the official explanation
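Calling eval() on a model puts it into evaluation mode, and calling train(True) puts it into training mode. That explanation is rather shallow, so let's dig into what actually happens when a model is switched into evaluation mode.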
Here is the source of train():
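```python
# Excerpt from nn.Module (torch/nn/modules/module.py)
def train(self: T, mode: bool = True) -> T:
    if not isinstance(mode, bool):
        raise ValueError("training mode is expected to be boolean")
    self.training = mode
    # This is the key step: recurse into every child module so that its
    # `training` flag changes too. Each child recurses into its own children,
    # which is how layers such as nn.Dropout end up following the model's mode.
    for module in self.children():
        module.train(mode)
    return self

# eval() takes no arguments and simply delegates to train(False)
def eval(self: T) -> T:
    return self.train(False)
```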
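To check the recursive behaviour described above, here is a minimal sketch (assuming a recent PyTorch install; the small two-level model is made up purely for illustration). It builds a nested nn.Sequential, calls eval() on the outer container, and inspects the training flag of every submodule.

```python
import torch.nn as nn

# A hypothetical model with a nested container: Dropout sits two levels below the root.
model = nn.Sequential(
    nn.Linear(4, 8),
    nn.Sequential(nn.ReLU(), nn.Dropout(p=0.5)),
)

# A freshly constructed module starts in training mode.
print(all(m.training for m in model.modules()))   # True

# eval() is train(False): the flag is set on the root and, through the
# recursive children() loop, on every descendant, including the nested Dropout.
returned = model.eval()
print(returned is model)                          # True: train() returns self
print(any(m.training for m in model.modules()))   # False
print(model[1][1].training)                       # False

# train() with no argument switches everything back (mode defaults to True).
model.train()
print(all(m.training for m in model.modules()))   # True
```

Because train() returns self, chained calls such as `model.eval()` inside an expression also work.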
2. In-depth analysis
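Conclusion first: any model class that inherits from nn.Module has the instance attribute training. Calling train() (its parameter is mode, default True) sets training to the value of mode. Calling eval() (no parameters) sets training to False, which is exactly equivalent to train(False). Because train() recurses into child modules, the change reaches every submodule of the model.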
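Ultimately, the training flag determines the arguments that Dropout and BatchNorm pass to their underlying functions. In the usual train(True) mode, Dropout randomly zeroes activations, and BatchNorm normalizes with the current batch's statistics while updating its running averages. After eval(), Dropout no longer does anything (it passes its input through unchanged), while BatchNorm does not switch off: with the default track_running_stats=True it normalizes with the running mean and variance accumulated during training.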
Below is part of the nn.Module source; note that self.training defaults to True:
```python
# Excerpt from nn.Module (torch/nn/modules/module.py, older PyTorch release)
import torch
from collections import OrderedDict
from typing import Optional

class Module:
    dump_patches: bool = False
    _version: int = 1
    training: bool
    _is_full_backward_hook: Optional[bool]

    def __init__(self):
        torch._C._log_api_usage_once("python.nn_module")
        self.training = True  # every module starts out in training mode
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._non_persistent_buffers_set = set()
        self._backward_hooks = OrderedDict()
        self._is_full_backward_hook = None
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
```
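Because every nn.Module starts with training = True, a custom layer can branch on self.training in the same way the built-in layers do. A minimal sketch; NoisyIdentity is a hypothetical layer invented for this example, not part of PyTorch:

```python
import torch
import torch.nn as nn

class NoisyIdentity(nn.Module):
    """Hypothetical example layer: adds Gaussian noise only in training mode."""

    def __init__(self, std: float = 0.1):
        super().__init__()  # nn.Module.__init__ sets self.training = True
        self.std = std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            return x + torch.randn_like(x) * self.std
        return x  # eval mode: pass the input through unchanged

layer = NoisyIdentity()
x = torch.ones(3)
print(layer.training)            # True
layer.eval()
print(torch.equal(layer(x), x))  # True
```

The layer never needs to be told about eval() explicitly; it only reads the flag that train()/eval() set.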
This is the Dropout2d source, which shows how the training value affects it:
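```python
# Excerpt from torch/nn/modules/dropout.py
class Dropout2d(_DropoutNd):
    def forward(self, input: Tensor) -> Tensor:
        # self.training becomes F.dropout2d's `training` argument; when it is False the input is returned unchanged
        return F.dropout2d(input, self.p, self.training, self.inplace)
```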
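A quick way to observe this, assuming a recent PyTorch version. The input shape, the seed, and p=0.5 are arbitrary choices for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout2d(p=0.5)
x = torch.ones(1, 4, 2, 2)  # (N, C, H, W): 4 channels of 2x2 ones

drop.train()
y_train = drop(x)
# In training mode whole channels are zeroed; surviving channels are
# scaled by 1 / (1 - p) = 2.0, so every element is either 0.0 or 2.0.
print(y_train[0, :, 0, 0])

drop.eval()
y_eval = drop(x)
# In eval mode F.dropout2d receives training=False and returns the input unchanged.
print(torch.equal(y_eval, x))  # True
```

Which channels are dropped depends on the random seed; what does not change is that eval mode is an exact identity.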
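Here is an excerpt from the forward() of _BatchNorm, the base class of BatchNorm2d (ultimately a subclass of nn.Module). When self.training is True, the layer normalizes with the current batch's statistics and updates its running averages. When self.training is False, it does not simply stop working; instead, bn_training depends on whether running mean and variance buffers exist. If they do (the default, track_running_stats=True), the stored running statistics are used; if they are None (track_running_stats=False), batch statistics are used even in eval mode.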
```python
# Excerpt from _BatchNorm.forward (torch/nn/modules/batchnorm.py)
if self.training and self.track_running_stats:
    # TODO: if statement only here to tell the jit to skip emitting this when it is None
    if self.num_batches_tracked is not None:  # type: ignore[has-type]
        self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore[has-type]
        if self.momentum is None:  # use cumulative moving average
            exponential_average_factor = 1.0 / float(self.num_batches_tracked)
        else:  # use exponential moving average
            exponential_average_factor = self.momentum

# Decide whether batch statistics (True) or the stored running buffers (False) are used;
# bn_training is later passed as the `training` argument of F.batch_norm.
if self.training:
    bn_training = True
else:
    bn_training = (self.running_mean is None) and (self.running_var is None)
```
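Both branches of bn_training can be observed directly. A minimal sketch using nn.BatchNorm1d (it shares the _BatchNorm.forward logic shown above with BatchNorm2d), with arbitrary made-up input data:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)  # track_running_stats=True by default
x = torch.randn(16, 3) * 5 + 10  # made-up batch whose mean/var differ from the initial running stats (0 and 1)

# Training mode: bn_training=True, so the batch's own statistics are used
# and running_mean / running_var are updated (momentum defaults to 0.1).
bn.train()
bn(x)
print(bn.num_batches_tracked)  # tensor(1)

# Eval mode with tracked buffers: bn_training=False, so the stored
# running_mean / running_var are used and left unchanged.
bn.eval()
mean_before = bn.running_mean.clone()
out_eval = bn(x)
print(torch.equal(bn.running_mean, mean_before))  # True
expected = (x - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps)  # weight=1, bias=0 at init
print(torch.allclose(out_eval, expected, atol=1e-5))  # True

# Without running stats the buffers are None, so bn_training stays True
# even in eval mode and the batch's own statistics are used.
bn_no_stats = nn.BatchNorm1d(3, track_running_stats=False).eval()
print(bn_no_stats.running_mean is None)  # True
out = bn_no_stats(x)
print(torch.allclose(out.mean(dim=0), torch.zeros(3), atol=1e-5))  # True
```

The last case is why "eval() turns BatchNorm off" is inaccurate: what eval() really changes is where the normalization statistics come from.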