pytorch中的model.eval() 和model.train()以及with torch.no_grad 还有torch.set_grad_enabled总结

一、pytorch中的model.eval() 和 model.train()

再pytorch中我们可以使用eval和train来控制模型是出于验证还是训练模式,那么两者对网络模型的具体影响是什么呢?

1. model.eval()

eval主要是用来影响网络中的dropout层和batchnorm层的行为。在dropout层保留所有的神经网络单元,batchnorm层使用在训练阶段学习得到的mean和var值。另外eval不会影响网络参数的梯度的计算,只不过不回传更新参数而已。所以eval模式要比with torch.no_grad更费时间和显存。

2. model.train()

这个就是训练模式,是网络的默认模式。在这个模式下,dropout层会按照设置好的失活概率进行失活,batchnorm会继续计算数据的均值和方差等参数并在每个batch size之间不断更新。

参考:model.train() and model.eval() do not change any behavior of the gradient calculations, but are used to set specific layers like dropout and batchnorm to evaluation mode (dropout won’t drop activations, batchnorm will use running estimates instead of batch statistics).

3. with troch.no_grad

with torch.no_grad会影响网络的自动求导机制,也就是网络前向传播后不会进行求导和进行反向传播。另外他不会影响dropout层和batchnorm层。
对dropout进行验证:

import torch
import torch.nn as nn

drop = nn.Dropout()

x = torch.ones(1, 10)

print(x)
drop.train()
print(drop(x))

drop.eval()
print(drop(x))

drop.train()
print(drop(x))

with torch.no_grad():
    print(drop(x))

结果:

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
tensor([[0., 2., 2., 2., 2., 0., 2., 2., 2., 0.]])
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
tensor([[0., 0., 0., 2., 0., 2., 2., 2., 0., 0.]])
tensor([[0., 0., 2., 0., 2., 2., 0., 2., 0., 0.]])

此结果也就表明with torch.no_grad() 对dropout不会产生影响。

4. torch.set_grad_enabled(mode)

与with troch.no_grad 相似,会将在这个with包裹下的所有的计算出的 新的变量 的required_grad 置为false。但原有的变量required_grad 不会改变。这实际上也就是影响了网络的自动求导机制。与with torch.no_grad() 相似,不过接受一个bool类型的值。

参考:(https://discuss.pytorch.org/t/confused-about-set-grad-enabled/38417)

  1. model.train() and model.eval() change the behavior of some layers. E.g. nn.Dropout won’t drop anymore and nn.BatchNorm layers will use the running estimates instead of the batch statistics. The torch.set_grad_enabled line of code makes sure to clear the intermediate values for evaluation, which are needed to backpropagate during training, thus saving memory. It’s comparable to the with torch.no_grad() statement but takes a bool value.

  2. All new operations in the torch.set_grad_enabled(False) block won’t require gradients. However, the model parameters will still require gradients.

  • 17
    点赞
  • 74
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值