一句话实现自动调整batch_size,再也不会Cuda out of memory

需要huggingface的accelerate库

核心函数(装饰器)
accelerate.find_executable_batch_size(function: callable = None, starting_batch_size: int = 128):
官方文档
(很短,小白基本看不懂,实际只需改2句话)

传统代码

以一个torch的valid_loop函数为例,valid_loop函数需要以batch_size为第一个参数(必须位置上是第一个,不能用关键字传参)

# 验证loop,验证一个epoch
def valid_loop_fn(batch_size: int, dataset, model, data_collator=None):
    model.eval()
    losses = []
    data_loader = DataLoader(dataset, batch_size=batch_size,
                             shuffle=False, collate_fn=data_collator)

    with torch.no_grad():
        tk0 = tqdm(data_loader, total=len(data_loader), desc='valid')
        for step, batch in enumerate(tk0):
            outputs = model(**batch)
            loss = outputs[0]
            losses.append(loss.cpu().item())

    # 计算平均loss
    avg_loss = np.mean(losses)
    return avg_loss

正常情况下,如果传入一个较大的batch_size例如512,就会报错

valid_losses = valid_loop_fn(512,tokenized_dataset['test'], model, data_collator)

然后RuntimeError: CUDA out of memory
然后就要手动修改batch_size,很麻烦

改进代码

此时,只需要accelerate.find_executable_batch_size对valid_loop进行装饰即可,即函数前面加一句@accelerate.find_executable_batch_size(starting_batch_size=512)
如果显存不够,会自动将batch_size减半,不会报错

import accelerate
# 验证loop,验证一个epoch
@accelerate.find_executable_batch_size(starting_batch_size=128)
def valid_loop_fn(batch_size: int, dataset, model, data_collator=None):
    model.eval()
    losses = []
    data_loader = DataLoader(dataset, batch_size=batch_size,
                             shuffle=False, collate_fn=data_collator)

    with torch.no_grad():
        tk0 = tqdm(data_loader, total=len(data_loader), desc='valid')
        for step, batch in enumerate(tk0):
            outputs = model(**batch)
            loss = outputs[0]
            losses.append(loss.cpu().item())

    # 计算平均loss
    avg_loss = np.mean(losses)
    return avg_loss

另一个改动,就是调用valid_loop时不需要再传入batch_size,如果传入反而会报错。因为装饰器已经帮你设定了batch_size,你只需要告诉被装饰的函数除了batch_size以外的信息即可,如下

valid_losses = valid_loop_fn(tokenized_dataset['test'], model, data_collator)

经过一点微小的改动,再也不会因为batch_size太大而报错了
(*^▽^*)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值