torch中的model.eval()、model.train()详解

👨‍💻个人简介: 深度学习图像领域工作者
🎉工作总结链接:https://blog.csdn.net/qq_28949847/article/details/128552785
             链接中主要是个人工作的总结,每个链接都是一些常用demo,代码直接复制运行即可。包括:
                    📌1.工作中常用深度学习脚本
                    📌2.torch、numpy等常用函数详解
                    📌3.opencv 图片、视频等操作
                    📌4.个人工作中的项目总结(纯干活)
🎉视频讲解: 以上记录,通过B站等平台进行了视频讲解使用,可搜索 ‘Python图像识别’ 进行观看
              B站:Python图像识别
              抖音:Python图像识别
              西瓜视频:Python图像识别


记录此博客原因:
在进行车牌识别时,发现trian的模型loss很低,效果很好,但是当预测时,效果就不好,经过排查是因为模型中添加了BN层和dropout层,但是在预测时,没有添加eval()导致的,下面是详细记录。

简介

在使用pytorch训练和预测时会分别使用到以下两行代码:

model.train()
model.eval()

下面对两行代码的具体作用进行记录,

1. model.train()

在使用 pytorch 构建神经网络的时候,训练过程中会在程序上方添加一句model.train(),作用是 启用 batch normalization 和 dropout 。

如果模型中有BN层(Batch Normalization)和 Dropout ,需要在 训练时 添加 model.train()。

model.train() 是保证 BN 层能够用到 每一批数据 的均值和方差。对于 Dropout,model.train() 是 随机取一部分 网络连接来训练更新参数。

2. model.eval()

model.eval(),简而言之,就是评估模式,而非训练模式。
在评估模式下,batchNorm层,dropout层等网络层会被关闭,不启用BN层和dropout层,使用训练时模型中保存的参数,从而使得评估时结果不会发生偏移。

在对模型进行评估时,应该配合使用with torch.no_grad()model.eval()

...
model.eval()
with torch.no_grad():
    Evaluation
...

如果模型中有batchNorm以及dropout等层,不添加model.eval()的话,结果是不可预料的。本人在进行评估LPRNet车牌识别时,就出现了这种情况,忘记添加eval(),结果飘忽不定。

不添加eval(),结果如下:
在这里插入图片描述

添加eval(),结果如下:

在这里插入图片描述
可以看出添加eval()后,模型的准确度立刻就上去了。

2.1 总结

如果模型中有 BN 层(Batch Normalization)和 Dropout,在 测试时 添加 model.eval()。

model.eval() 是保证 BN 层能够用 全部训练数据 的均值和方差,即测试过程中要保证 BN 层的均值和方差不变。对于 Dropout,model.eval() 是利用到了 所有 网络连接,即不进行随机舍弃神经元。

1)训练过程中BN的变化。
在训练过程中BN会不断的计算均值和方差,训练结束后得到最终的均值和方差,在此处将其记为mean_train,variance_train。

2)预测过程中BN的变化。
预测过程中如果不使用model.eval()的话,BN层还是会根据输入的预测数据继续计算均值和方差,假设输入一条预测数据后,BN层计算得到其均值和方差分别为mean_test,variance_test,此时BN层的均值和方差则变成了(mean_train+mean_test),(variance_train+variance_test),相比于训练过程中的均值和方差发生了变化因此会导致预测结果发生变化。

如果使用model.eval()则BN层就不会再计算预测数据的均值和方差,即在预测过程中BN层的均值和方差就是训练过程得到的均值和方差mean_train,variance_train,此时预测结果就不会再发生变化。

3)训练过程中Dropout的变化
训练过程中依据设置的dropout比例会使一部分的网络连接不进行计算。

4)预测过程中Dropout的变化
预测过程中如果不使用model.eval()的话,依然会使一部分的网络连接不进行计算,而使用model.eval()后就是所有的网络连接均进行计算。

  • 3
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Python图像识别

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值