pytorch nn.Module train和eval 函数 深入解析

1、先看看官方的解释

请添加图片描述请添加图片描述
模型调用eval() 就是设置为评估模式,调用train(True)就为训练模式,这个说的很不透彻,我们一起来探究究竟为何model模型就设置为了评估模式

#看看train的源码:
    def train(self: T, mode: bool = True) -> T:   
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        self.training = mode
        for module in self.children():#这个很关键,让它的子模块里的training值也改变,才会随之nn.DropOut()的训练模式
            module.train(mode)
        return self

2、深度解析

先说结论:模型类继承了nn.Module 就有实例属性training。模型调用train() 【参数为mode,默认值为True】 会设置training值等于mode值。调用eval() 【没有参数】实际执行会设置training值为False,等同于train(False)。
而最后 training值会影响Dropout和BatchNorm的函数参数值的设置【使用或不使用】,一般的train(True)模式,使用Dropout和BatchNorm,而eval() Dropout和BatchNorm则不会"工作"。

#这里是nn.Module 的部分源码,可见self.training 的值默认为True
class Module:
    dump_patches: bool = False
    _version: int = 1
    training: bool
    _is_full_backward_hook: Optional[bool]
    def __init__(self):
        torch._C._log_api_usage_once("python.nn_module")
        self.training = True
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._non_persistent_buffers_set = set()
        self._backward_hooks = OrderedDict()
        self._is_full_backward_hook = None
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
#这是Dropout2d的源码,可见training值对它的影响
class Dropout2d(_DropoutNd):
	def forward(self, input: Tensor) -> Tensor:
	        return F.dropout2d(input, self.p, self.training, self.inplace)

#这里是BatchNorm2d(继承自nn.Module)的节选,可见,self.training为True,则工作。而self.training为False也并不会不工作,而是看均值和方差是否之前计算出。

 			if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore[has-type]
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum
 		if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)
  • 13
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值