Pytorch model.train()

1.前言

在使用Pytorch进行模型的训练和测试时,我们总能在训练部分的最前面看到model.train(),在测试部分最前面看到model.eval()。这两种语法起到什么作用呢?

BNDropout 的介绍,可参考 Dropout & Batch Normolization_长命百岁️的博客-CSDN博客

2.作用及原因

主要是对Batch NormalizationDropout 层有影响。因为这两层在训练和测试时进行的操作是不同的。

2.1.Batch Normalization

2.1.1训练时的BN层

在训练时,我们对每个 batch 的数据,求取其均值和方差,然后进行归一化。如下图所示

在这里插入图片描述

一个 batch 的数据量是比较小的,所以按 batch 处理是比较方便的。

我们运行 model.train() 之后,就告诉了 BN 层,对之后输入的每个 batch 独立计算其均值和方差,BN 层的参数是在不断变化的。

2.1.2测试时的BN层

我们可以进行深度学习的目的就是:利用训练数据来预测测试数据的结果。其实我们默认了,训练数据和测试数据是属于相同分布的。

一个简单的例子:线性回归是默认测试数据落在,使用训练数据分布得到的直线上的

那么在测试时,最理想的状态当然是在测试的时候采用整个training data的平均值和标准差。但是训练集可能比较大,去计算均值和方差并不方便。也有可能training data是分batch进入的,很可能并没有留下来,那根本没办法计算平均值和方差。

  • 其实首先会想到计算测试数据的均值和方差,但是一般测试数据也比较多,计算量大

可行的方法:将训练过程中每个 batch μ \mu μ σ \sigma σ 都保存下来,然后加权平均当做整个训练数据集的 μ \mu μ σ \sigma σ ,同时用于测试。

model.eval() 就是告诉 BN 层,我现在要测试了,你用刚刚统计的 μ \mu μ σ \sigma σ 来测试我,不要再变了。

Pytorch 中的 BN

CLASS torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  • momentum 就是权重值。
  • 假设 xtest 时的均值或标准差
    • 第一轮的 x 就是这一轮训练数据计算出来的值
    • 之后的每一轮,假设 x i x_i xi 为这一轮训练数据计算出来的值,我们按照 x = ( 1 − m o m e n t u n ) ∗ x + m o m e n t u m ∗ x i x = (1-momentun) * x + momentum * x_i x=(1momentun)x+momentumxi 来更新,这就是训练阶段的加权过程

2.2.Dropout

这个参照上面就比较好理解了,就是区分出 traintest。我们都知道,train 时,Dropout 会遮住一些神经元,从而增强模型的泛化能力。但是在 test 时,不应该遮住神经元,而应该使用整个网络来测试。

model.train() 就是告诉 Dropout 层,你下面应该遮住一神经元

model.test() 就是告诉 Dropout 层,你下面别遮住了,我全都需要

3.总结

  • model.train() 源码

        def train(self, mode=True):
            r"""Sets the module in training mode.
    
            This has any effect only on certain modules. See documentations of
            particular modules for details of their behaviors in training/evaluation
            mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
            etc.
    
            Returns:
                Module: self
            """
            self.training = mode
            for module in self.children():
                module.train(mode)
            return self
    
  • model.eval() 源码

    def eval(self):
        return self.train(False)
    

两个函数通过改变self.training = True / False 来告知一些特定的层(BN,Dropout),应该启用 train 时的功能还是 test 时的功能。

  • 39
    点赞
  • 94
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 5
    评论
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

长命百岁️

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值