model.train()&&model.eval()&&with torch.no_grad()用法

1.model.train()

启用 Batch Normalization 和 Dropout
如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train()。model.train()是保证BN层能够用到每一批数据的均值和方差。对于Dropout,model.train()是随机取一部分网络连接来训练更新参数。

用法:注意model.train()的位置


    for epoch in range(num_epoches):
        running_loss = torch.Tensor([0]).to(device)
        pbar = tqdm(enumerate(trainloader), total=len(trainloader),position=0)
        model.train()
        for i, data in pbar:
            path_img, img, labels = data
            img, labels = img.to(device), labels.to(device)
            optimizer.zero_grad()
            out = model(img)
            c = out.size()
            loss = criterion(out, labels.long())
            loss.backward()
            optimizer.step()

​

model.train()一定要是在epoch中并且在dataloader循环前的位置上

2.model.eval

不启用 Batch Normalization 和 Dropout
如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()。model.eval()是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。
 

    model.eval()
    with torch.no_grad():
        pbar = tqdm(enumerate(testloader), total=len(testloader),position=0)
        for i, data in pbar:
            path_img, img, labels = data
            img, labels = img.to(torch.float32).to(device),
labels.to(torch.float32).to(device)
            d = img.size()
            e = labels.size()
            out = model(img)
            f = out.size()
            _, pred = torch.max(out.data, 1)

3.with torch.no_grad()

with torch.no_grad()则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。

注意:在测试模型时候,记得在with torch.no_grad()前面加上model.eval()   ,不加model.eval() 会导致测试精度很低

  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值