Pytorch 中的 eval 模式,train 模式 和 梯度上下文管理器 torch.no_grad

前言

本文将简要说明下 Pytorch 中 model.eval() (模型评估模式),model.train() (模型训练模式) 和 torch.no_grad() (取消梯度计算上下文管理器) 的作用与用法。

model.train() 和 model.eval()

这里的 model 指的是在 Pytorch 中定义的模型,需要继承自 torch.nn.Module,下同。

model.train() 会将 model 设置成训练模式,这只会影响 model 中特定的一些模块,比如:Dropout、BatchNorm 等,因为这些模块在训练阶段和验证阶段(或测试阶段)有着不同的行为。而其他大部分模块(如,nn.Linear、nn.Embedding、nn.Conv1d 等)在训练阶段和验证以及测试阶段都具有同样的行为,所以不会受此模式的影响。

关于 Dropout 在不同阶段有着不同行为的解释可参考我的另一博文:神经网络正则化方法总结——Dropout

model.eval() 会将 model 设置成评估模式,同上,这只会影响 model 中特定的一些模块,比如:Dropout、BatchNorm 等。

若是在模型的非训练阶段(如 evaluation 阶段)未使用 model.eval()model 设置成评估模式,有可能会造成同一样本的多次推断结果不一致的情况(这可是一个很大的问题…)

torch.no_grad()

torch.no_grad 是一个上下文管理器,在其管理范围内 Pytorch 不再计算模型各参数的梯度,即使参数的 requires_grad 属性为 True,这能有效减少模型计算时所需的内存/显存(因为保存参数梯度需要大量的内存/显存)。这在模型的 evaluation 和 test 以及 predict 这些非训练阶段中很有用,常见用法有两种,一种是使用 with 语句,另一种是使用修饰器:

x = torch.tensor([1], requires_grad=True)

# 使用 with 语句
with torch.no_grad():
	y = x * 2
print(y.requires_grad) # 输出为 False

# 使用 torch.no_grad() 修饰器
@torch.no_grad()
def doubler(x):
	return x * 2
z = doubler(x)
print(z.requires_grad) # 输出为 False

在 Pytorch 模型的非训练阶段,往往需要同时使用 torch.no_grad()model.eval().

在从非训练阶段跳转到训练阶段(即 train 阶段)时,别忘了使用 mode.train() 命令。

参考源

  • 13
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值