PyTorch学习笔记:model.train()与model.eval()——切换训练模式与测试模式
功能:用于切换模型的模式,model.train()
将模型切换为训练模式,model.eval()
将模型切换为测试模式。
主要区别
BN与DropOut运算规则不同
BN层
定义:对于所有的batch中的同一个channel的数据元素进行标准化处理,即如果有C个通道,无论有多少个batch,都会在通道维度上进行标准化处理,一共进行C次,训练阶段与测试阶段均值方差的计算不同。
训练阶段:将同一下batch通道相同的值取出来,一块计算均值和方差,即计算当前观测值的均值和方差,并且利用当前数据的均值方差更新全局的均值方差。
测试阶段:利用模型存储的全局均值方差做标准化运算,并且不改变全局均值方差的数值。
注:
- 只要有数据传入BN层,即做了前向传播,则BN层中存储的全局均值和方差就会做相应的更新,无需做反向传播;
- 具体细节可见:《nn.BatchNorm2d学习笔记》。
DropOut层
定义:在训练阶段按某种概率随即将输入的张量元素随机归零,常用的正则化器,用于防止网络过拟合。
训练阶段:数据首先被放大 1 1 − p \frac1{1-p} 1−p1倍,之后再以概率 p p p执行归零操作。
测试阶段:不对数据做变化。
注:
- 具体细节可见:《nn.Dropout学习笔记》。
补充
测试模式下也可以求解参数梯度,因此也会占用一定的显存空间,如果想要丢弃梯度的运算,节省GPU算力和显存,则可以引入with torch.no_grad():
方法,如下面:
def test(model,dataloader):
model.eval() # 切换到测试模式
with torch.no_grad(): #with下内容不进行grad计算
...
官方文档
model.train():https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=train#torch.nn.Module.train
model.eval():https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=eval#torch.nn.Module.eval
torch.no_grad:https://pytorch.org/docs/stable/generated/torch.no_grad.html?highlight=torch+no_grad#torch.no_grad
注:以上仅是笔者个人见解,若有问题,欢迎指正
初步完稿于:2023年2月4日