PyTorch 的 Module 类中为模型设定了两种模式:training 和 evaluation
以下几点需要注意,假设定义了一个继承自 torch.nn.Module 的模型 model:
(1)表征 model 模式的是继承自 torch.nn.Module 的一个布尔变量 training,默认为 True;
(2)可使用 model.train() 或 model.train(True) 进入 training 模式;可使用 model.eval() 或 self.train(False) 进入 evaluation 模式;
(3)model.train() 或 model.train(True) 会将当前 model 及其子模块的 training 设为 True,如果传入参数 False,则会将当前模块及其子模块的 training 参数均设为 False;
eval() 则直接会调用 self.train(False) 来表征模式的变化;
(4)两种模式的设定仅对特定的 module 起作用,比如 Dropout,BatchNorm 等,比如在 torch.nn.Dropout 中,如果 training = True,则会执行 dropout 操作,否则不会进行 dropout;