快来优化你的Pytorch的显存占用吧

介绍

本文介绍了使用混合精度训练验证禁用梯度来优化显存的占用。根据笔者实测,混合精度训练对网络的影响几乎可以忽略不及,但是显存可以降低一半以上。

混合精度训练

from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler
scaler = GradScaler()

optimizer = optim.SGD(model.parameters(), lr=0.04, momentum=0.7, weight_decay=5e-4)

for epoch in range(0, n_epochs):
    train_loss = 0.0
    valid_loss = 0.0
    model.train() 
    for data, target in train_loader:
        optimizer.zero_grad()
        data = data.to(device)
        target = target.to(device)
        with autocast():
            output = model(data).to(device) 
            loss = criterion(output, target)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()


  1. 初始化梯度放大器 scaler = GradScaler()
  2. 在模型推理部分加上 with autocast()
  3. 使用Scale防止半精度发生的数据溢出!这点非常重要
  4. 遗憾的是并没有发现存在加速效果。

禁用梯度

with torch.no_grad():
    for data, target in test_loader:
        data = data.to(device)
        target = target.to(device)
        output = model(data).to(device)
        loss = criterion(output, target)
        valid_loss += loss.item()*data.size(0)
        _, pred = torch.max(output, 1)    
        correct_tensor = pred.eq(target.data.view_as(pred))
        total_sample += data.size(0)
        right_sample += list(correct_tensor).count(True)
        print()
        print("Accuracy:",100*right_sample/total_sample,"%")
        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
            epoch, train_loss, valid_loss))
        # 如果验证集损失函数减少,就保存模型。
        if valid_loss <= valid_loss_min:
            print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving Person ...'.format(valid_loss_min,valid_loss))
            # torch.save(model.state_dict(), 'resnet18_cifar10.pt')
            valid_loss_min = valid_loss

开启禁用梯度是因为,pytorch即便在不调用backward的情况下也会存在梯度,这导致了测试也带来了大量的显存占用,通过torch.no_grad()来禁止梯度,可以有效防止显存爆炸。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值