python 调用gpu算力_GPU捉襟见肘还想训练大批量模型?谁说不可以

深度学习模型和数据集的规模增长速度已经让 GPU 算力也开始捉襟见肘,如果你的 GPU 连一个样本都容不下,你要如何训练大批量模型?通过本文介绍的方法,我们可以在训练批量甚至单个训练样本大于 GPU 内存时,在单个或多个 GPU 服务器上训练模型。

2018 年的大部分时间我都在试图训练神经网络时克服 GPU 极限。无论是在含有 1.5 亿个参数的语言模型(如 OpenAI 的大型生成预训练 Transformer 或最近类似的 BERT模型)还是馈入 3000 万个元素输入的元学习神经网络(如我们在一篇 ICLR 论文《Meta-Learning a Dynamical Language Model》中提到的模型),我都只能在 GPU 上处理很少的训练样本。

但在多数情况下,随机梯度下降算法需要很大批量才能得出不错的结果。如果你的 GPU 只能处理很少的样本,你要如何训练大批量模型?

有几个工具、技巧可以帮助你解决上述问题。在本文中,我将自己用过、学过的东西整理出来供大家参考。

在这篇文章中,我将主要讨论 PyTorch 框架。有部分工具尚未包括在 PyTorch(1.0 版本)中,因此我也写了自定义代码。

我们将着重探讨以下问题:在训练批量甚至单个训练样本大于 GPU 内存,要如何在单个或多个 GPU 服务器上训练模型;

如何尽可能高效地利用多 GPU 机器;

在分布式设备上使用多个机器的最简单训练方法。

在一个或多个 GPU 上训练大批量模型

你建的模型不错,在这个简洁的任务中可能成为新的 SOTA,但每次尝试在一个批量处理更多样本时,你都会得到一个 CUDA RuntimeError:内存不足。

aa5d631f87a9697c365d3fa253f39732.png

这位网友指出了你的问题!

但你很确定将批量加倍可以优化结果。你要怎么做呢?

这个问题有一个简单的解决方法:梯度累积。

8d69fbbce769d993b8c586e16ae18346.png

梯度下降优化算法的五个步骤。

与之对等的 PyTorch 代码也可以写成以下五行:predictions = model(inputs)               # Forward pass

loss = loss_function(predictions, labels) # Compute loss function

loss.backward()                           # Backward pass

optimizer.step()                          # Optimizer step

predictions = model(inputs)               # Forward pass with new parameters

在 loss.backward() 运算期间,为每个参数计算梯度,并将其存储在与每个参数相关联的张量——parameter.grad 中。

累积梯度意味着,在调用 optimizer.step() 实施一步梯度下降之前,我们会对 parameter.grad 张量中的几个反向运算的梯度求和。在 PyTorch 中这一点很容易实现,因为梯度张量在不调用 model.zero_grad() 或 optimizer.zero_grad() 的情况下不会重置。如果损失在训练样本上要取平均,我们还需要除以累积步骤的数量。

以下是使用梯度累积训练模型的要点。在这个例子中,我们可以用一个大于 GPU 最大容量的 accumulation_steps 批量进行训练:model.zero_grad()                                   # Reset gradients tensors

for i, (inputs, labels) in enumerate(t

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值