目录
一、model.train()、model.eval()区别
一、model.train()、model.eval()区别
model.train()和model.eval()函数主要用于将模型中的training属性设置为True或False两种状态,training属性会直接影响BatchNorm层(链接)和Dropout层的运行机理。
1.1、model.train()函数
调用model.train()函数将model.training属性变成True状态,BN层对数据进行标准化处理使用的均值和方差是当前批次数据所求,运用移动平均(moving average)策略更新running_mean和running_var。Dropout层失活一部分神经元。
1.2、model.eval()函数
调用model.eval()函数将model.training属性变成False状态,BN层对样本数据标准化使用整个训练数据的running_mean和running_var。BN层不进行反向传播和梯度更新,running_mean和running_var保持不变。Dropout层使用所有神经元。
二、torch.no_grad()、以及detach()
2.1、detach函数
detach函数生成与原始张量共享数据的新张量,将新张量的requires_grad标志设置成False。张量不参与计算图的梯度计算。detach函数一次只能用于一个变量速度较慢。
a = torch.tensor([1.1],requires_grad=True)
print(a.requires_grad) #True
a = a.detach()
print(a.requires_grad) #False
2.2、torch.no_grad函数
torch.no_grad函数与detach函数使用效果一样,放入上下文管理器中的操作不构建计算图,节省内存和显存消耗。torch.no_grad函数等价于@torch.no_grad()
x = torch.randn(3,2,requires_grad=True)
w = torch.tensor([1.1,2.2])
b = torch.ones(3)
z = torch.matmul(x, w)+b
print(z.requires_grad)
with torch.no_grad():
z = torch.matmul(x, w)+b
print(z.requires_grad)
@torch.no_grad()
def fun(x,w,b):
return torch.matmul(x, w)+b
z = fun(x,w,b)
print(z.requires_grad)
##############
True
False
False
参考博文:
pytorch中model.train()和model.eval()的区别_想念@思恋的博客-CSDN博客
pytroch:model.train()、model.eval()的使用_model.train和model.eval_像风一样自由的小周的博客-CSDN博客