0 model.eval()
经常在模型推理代码的前面, 都会添加model.eval()
, 主要有3个作用:
- 1 不进行dropout
- 2 不更新batchnorm的mean 和var 参数
- 3 不进行梯度反向传播, 但梯度仍然会计算
1 torch.no_grad()
torch.no_grad的一般使用方法是, 在代码块外面用with torch.no_grad()
给包起来。 如下面这样:
with torch.no_grad():
# your code
它的主要作用有2个:
- 1 不进行梯度的计算(当然也就没办法反向传播了), 节约显存和算力
- 2 dropout和batchnorn还是会正常更新
2 异同
从上面的介绍中可以非常明确的看出,它们的相同点是一般都用在推理阶段, 但它们的作用是完全不同的, 也没有重叠。 可以一起使用。