Pytorch中model.train()和model.eval()的作用

我们在使用pytorch训练模型的时候会先加上一句

model.train()

模型训练完做推理时,也会先加上一句

model.eval()

这两句话的作用是告诉模型当前是在训练还是推理阶段。因为我们的模型的某些部分在做训练和推理时的操作是不一样的,如BN层的计算过程,Faster RCNN在训练和推理时预选框的选择等等。
那么这两句话背后是做了什么操作来告诉模型当前阶段是训练还是推理呢?其实train()eval()方法是在torch的Module类中实现的。源码如下

class Module(object):
    _version = 1

    def __init__(self):
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        torch._C._log_api_usage_once("python.nn_module")

        self.training = True
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
	......
	......
	......
	......
	......
    def train(self, mode=True):
        r"""Sets the module in training mode.

        This has any effect only on certain modules. See documentations of
        particular modules for details of their behaviors in training/evaluation
        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
        etc.

        Args:
            mode (bool): whether to set training mode (``True``) or evaluation
                         mode (``False``). Default: ``True``.

        Returns:
            Module: self
        """
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

    def eval(self):
        r"""Sets the module in evaluation mode.

        This has any effect only on certain modules. See documentations of
        particular modules for details of their behaviors in training/evaluation
        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
        etc.

        This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.

        Returns:
            Module: self
        """
        return self.train(False)

我们可以看到Module类中有定义一个参数training,并初始化为True,

self.training=True

我们自己在写模型网络时,会先继承torch.nn.Module,

class Network(nn.Module):
	"""
	"""

模型搭建完成后,先对模型进行初始化,

model=Network()

此时model就继承了torch.nn.Module,执行model.train()时,实际执行的操作是在Module的train()方法,将模型的参数training设置为True,并且每个子代Module的training设置为True。

    def train(self, mode=True):
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

执行model.eval()时,实际执行的操作是在Module的eval()方法,eval()通过调用train(),传入False的参数,将training设置为False。

    def eval(self):
        return self.train(False)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

GHZhao_GIS_RS

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值