Crystal的博客
Pytorch中model.train()和model.eval()
问题:
刚开始接触pytorch时,发现别人的代码中,会在训练模型的一开始写上model.train(),对应的,在测试模型一开始写上model.eval()。我尝试不使用这两句,发现程序仍然能够正常运行,所以就非常好奇这两句有什么作用,为什么要这么写。
解答:
(1) 用法:这两个方法是针对在模型训练和评估时采用不同的方式的情况。如果模型中有BN层(Batch Normalization)和正则化Dropout,需要在训练模型的一开始添加model.train(),在测试模型的一开始添加model.eval()。
(2) 原因: model.train()保证BN层使用每一批数据的均值和方差,而model.eval()保证BN用全部训练数据的均值和方差;而对于Dropout,model.train()是随机选取一部分网络连接来训练模型,更新参数,而model.eval()使用了所有网络连接。
补充:
BN针对网络中的每一层进行归一化处理,训练时是分批的,而测试的时候针对的是全部数据。
Dropout能够以一定的概率激活神经元,忽略网络中的一些神经元,因此能够减少过拟合的现象。