model.train()、model.eval()和torch.no_grad()以及detach()区别

目录

一、model.train()、model.eval()区别

1.1、model.train()函数

1.2、model.eval()函数

二、torch.no_grad()、以及detach()

2.1、detach函数

2.2、torch.no_grad函数


一、model.train()、model.eval()区别

model.train()和model.eval()函数主要用于将模型中的training属性设置为True或False两种状态,training属性会直接影响BatchNorm层(链接)和Dropout层的运行机理。

1.1、model.train()函数

调用model.train()函数将model.training属性变成True状态,BN层对数据进行标准化处理使用的均值和方差是当前批次数据所求,运用移动平均(moving average)策略更新running_mean和running_var。Dropout层失活一部分神经元。

1.2、model.eval()函数

调用model.eval()函数将model.training属性变成False状态,BN层对样本数据标准化使用整个训练数据的running_mean和running_var。BN层不进行反向传播和梯度更新,running_mean和running_var保持不变。Dropout层使用所有神经元。

二、torch.no_grad()、以及detach()

2.1、detach函数

detach函数生成与原始张量共享数据的新张量,将新张量的requires_grad标志设置成False。张量不参与计算图的梯度计算。detach函数一次只能用于一个变量速度较慢。

a = torch.tensor([1.1],requires_grad=True)
print(a.requires_grad)    #True
a = a.detach()
print(a.requires_grad)    #False


2.2、torch.no_grad函数

torch.no_grad函数与detach函数使用效果一样,放入上下文管理器中的操作不构建计算图,节省内存和显存消耗。torch.no_grad函数等价于@torch.no_grad()

x = torch.randn(3,2,requires_grad=True)
w = torch.tensor([1.1,2.2])
b = torch.ones(3)

z = torch.matmul(x, w)+b
print(z.requires_grad)

with torch.no_grad():
    z = torch.matmul(x, w)+b
print(z.requires_grad)

@torch.no_grad()
def fun(x,w,b):
    return torch.matmul(x, w)+b
z = fun(x,w,b)
print(z.requires_grad)

##############
True
False
False

参考博文:

pytorch中model.train()和model.eval()的区别_想念@思恋的博客-CSDN博客

pytroch:model.train()、model.eval()的使用_model.train和model.eval_像风一样自由的小周的博客-CSDN博客

  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值