显存不够用怎么办 —— 梯度累积

原文可直接运行代码

OOM 的困扰

经常有小伙伴遇到 OOM 报错以后来问 Dave 怎么办,如果遭受下面这些报错困扰那就是了。

在这里插入图片描述

内存溢出(Out-Of-Memory) 是计算机操作中的一种通常不希望遇到的状态,在这种状态下,无法分配额外的内存以供程序使用。这样状态下的系统将无法加载任何其他程序,并且由于许多程序可能在执行期间将额外的数据加载到内存中,因此这些程序将停止正常运行。

在神经网络的训练中,经常会出现图像尺寸很大又想增大batch size,无奈显存不足,但是大显存的 A6000 又挺贵的💰,那怎么办呢?

人民币小伙伴:没事,我有💰啊?

Dave:冒犯了冒犯了。。。

因为 Dave 家境贫寒,所以经常想用 GTX 3080 跑出 GTX 3090 的效果,除了之前 Dave 介绍过的 半精度训练以外 还有一个方法一直在用,推荐给小伙伴们 —— 梯度累加(积)

什么是梯度累加(积)

梯度累加是对多个批次的训练梯度进行累计,然后同时执行权重更新。这样的好处是可以只用一个批次占用的 GPU 显存,达到多个批次数量相加的 batch-size。

在这里插入图片描述

!pip install seaborn tqdm
import os
import timm
import torch
import random

from torch.cuda.amp import GradScaler
from torch.cuda.amp import autocast

import torchvision.transforms as transforms
import torchvision.datasets as dataset
import seaborn as sns
import numpy as np

from tqdm import tqdm

def seed_torch(seed=99):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = False
seed_torch()

准备 MNIST 数据

trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
train_set = dataset.MNIST(root='./', train=True, transform=trans, download=True)
test_set = dataset.MNIST(root='./', train=False, transform=trans)

正常训练(半精度)

# Dataloader
train_loader = torch.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=32768,
                 num_workers=8,
                 shuffle=True)

model = timm.create_model('efficientnet_b0', in_chans=1, num_classes=10).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

losses = []
scaler = GradScaler()

for epoch in range(10):
    loss_epoch = 0
    for i, (input, target) in enumerate(train_loader):
        with autocast():
            output = model(input.cuda())
            loss = loss_fn(output, target.cuda())

        scaler.scale(loss).backward()
        loss_epoch += loss.item() * len(input)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
    print('Loss: ', loss_epoch/len(train_set))
    losses.append(loss_epoch/len(train_set))

sns.relplot(kind="line",data=losses);

在这里插入图片描述

!nvidia-smi

在这里插入图片描述
正常训练占用 30 GB 显存

训练中加入梯度累加

  • batchsize 选择 8192
  • 梯度累加策略:每 4 个 iteration 进行一次模型更新(8192 * 4 = 32768)
  • 将批次敏感的 Batch Norm 更换成 Group Norm
train_loader = torch.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=8192,
                 num_workers=8,
                 shuffle=True)
model = timm.create_model('efficientnet_b0', in_chans=1, num_classes=10).cuda()
losses = []

# 将模型的 BatchNorm2d 更换成 GroupNorm
def convert_bn_to_gn(model):
    for child_name, child in model.named_children():
        if isinstance(child, torch.nn.BatchNorm2d):
            num_features = child.num_features
            setattr(model, child_name, torch.nn.GroupNorm(num_groups=1, num_channels=num_features))
        else:
            convert_bn_to_gn(child)

convert_bn_to_gn(model)
model = model.cuda()

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
loss_fn = torch.nn.CrossEntropyLoss()

scaler = GradScaler()
iters_to_accumulate = 4

for epoch in range(10):
    loss_epoch = 0
    for i, (input, target) in enumerate(train_loader):
        with autocast():
            output = model(input.cuda())
            loss = loss_fn(output, target.cuda())
            loss_epoch += loss.item() * len(input)
            loss = loss / iters_to_accumulate

        scaler.scale(loss).backward()
        if (i + 1) % iters_to_accumulate == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
    print('Loss: ', loss_epoch/len(train_set))
    losses.append(loss_epoch/len(train_set))

sns.relplot(kind="line",data=losses);

在这里插入图片描述

!nvidia-smi

在这里插入图片描述

占用显存明显降低到 13.7 GB左右,同时可以看到对 Loss 影响并不大。

原文可直接运行代码

在这里插入图片描述

image

  • 4
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Dave 扫地工

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值