Module类的构造函数:
def __init__(self):
"""
Initializes internal Module state, shared by both nn.Module and ScriptModule.
"""
torch._C._log_api_usage_once("python.nn_module")
self.training = True
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()
其中training属性表示BatchNorm与Dropout层在训练阶段和测试阶段中采取的策略不同,通过判断training值来决定前向传播策略。
对于一些含有BatchNorm,Dropout等层的模型,在训练和验证时使用的forward在计算上不太一样。在前向训练的过程中指定当前模型是在训练还是在验证。使用module.train()和module.eval()进行使用,其中这两个方法的实现均有training属性实现。
关于这两个方法的定义源码如下:
train():
def train(self, mode=True):
r"""Sets the module in training mode.
This has any effect only on certain modules. See documentations of
particular modules for details of their behaviors in training/evaluation
mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
etc.
Args:
mode (bool): whether to set training mode (``True``) or evaluation
mode (``False``). Default: ``True``.
Returns:
Module: self
"""
self.training = mode
for module in self.children():
module.train(mode)
return self
eval():
def eval(self):
r"""Sets the module in evaluation mode.
This has any effect only on certain modules. See documentations of
particular modules for details of their behaviors in training/evaluation
mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
etc.
This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.
Returns:
Module: self
"""
return self.train(False)
从源码中可以看出,train和eval方法将本层及子层的training属性同时设为true或false。
具体如下:
net.train() # 将本层及子层的training设定为True
net.eval() # 将本层及子层的training设定为False
net.training = True # 注意,对module的设置仅仅影响本层,子module不受影响
net.training, net.submodel1.training
关于train(),eval()函数可以查看pytorch中的model.train()和model.eval()