深度学习训练降低显存指南

一、小模块API参数inplace设置为True(省一点点)

比如:Relu()有一个默认参数inplace,默认设置为False,当设置为True时,计算时的得到的新值不会占用新的空间而是直接覆盖原来的值,进而可以节省一点点内存。

二、Apex半精度计算(省一半左右)

安装方式

git clone https://github.com/NVIDIA/apex
cd apex
python3 setup.py install

原理:一款Nvidia开发的基于PyTorch的混合精度训练加速神器,可以用短短三行代码(导包、初始化、反向传播)就能实现不同程度的混合精度加速,训练时间直接缩小一半。

from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # 这里是“欧一”,不是“零一”
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()

其中只有一个opt_level需要用户自行配置(# 这里是“欧x”,不是“零x”):

  • O0:纯FP32训练,可以作为accuracy的baseline;
  • O1:混合精度训练(推荐使用),根据黑白名单自动决定使用FP16(GEMM, 卷积)还是FP32(Softmax)进行计算。
  • O2:“几乎FP16”混合精度训练,不存在黑白名单,除了Batch norm,几乎都是用FP16计算。
  • O3:纯FP16训练,很不稳定,但是可以作为speed的baseline;

注:apex不能和contiguous-params一起用,否则无法实现模型参数优化。(因为这俩模型参数指针要打架)

三、删无用变量并及时清空显存垃圾(省将近一半)

这个的用法比较简单,就是先del变量,再cuda垃圾回收。

del answers_types_emb_all, answers_types_len
torch.cuda.empty_cache()

但这个的使用方式,不能在整个代码中只使用一次torch.cuda.empty_cache()——否则整体的最大显存不会变化。

正确的使用方法是:应该在整体代码的多个位置,分别del变量后,对应位置之后即使用torch.cuda.empty_cache()。

这样整体最大显存,就会下降(del的变量整体显存之和 - max的一坨del变量的显存)——分批清空显存,才能真正地用torch.cuda.empty_cache()降低显存。

四、重计算方法(省一半以上)

用pytorch的checkpoint断点保存机制,将连续的几个函数计算,分为若干段,保存了再计算——用保存加载的时间,换显存计算的空间

原来一个函数的计算方式:

sequence_enc = self.encoder(sequence_emb, sequence_mask, output_all_encoded_layers=False)[-1]

用checkpoint断点保存的方式:

from torch.utils.checkpoint import checkpoint
sequence_enc = checkpoint(self.encoder, sequence_emb, sequence_mask, torch.tensor([0]))[-1]

值得一提的是,checkpoint函数,要求输入的函数(如self.encoder),返回值必须为torch.tensor类型,如果原来函数(self.encoder)的返回值是一个数组

return all_encoder_layers

用torch.stack改造一下,就可以用checkpoint了

return torch.stack(all_encoder_layers, dim=0)

注:checkpoint后没有梯度记忆,解决方法参考Pytorch使用GPU进行训练注意事项 | 文艺数学君 (mathpretty.com)的方法二

五、效果

楼主同时用了上述四项节省显存的技术,将原始显存大于11GB的程序,降低到了只有2GB-3GB的水平

上述技术节省显存的程度(由多到少):

  1. 重计算方法
  2. Apex半精度计算
  3. 删无用变量并及时清空显存垃圾
  4. 小模块API参数inplace设置为True

上述技术实现的复杂程度(由易到难):

  1. 小模块API参数inplace设置为True
  2. Apex半精度计算
  3. 删无用变量并及时清空显存垃圾
  4. 重计算方法

PS 你如果有24GB以上显存的显卡的话,那就当我什么都没说。

当然还有些不是方法的方法:降低Batch的大小,减小模型的超参数,不需要计算显存时用torch.no_grad():等。

此外,代码变量之间的不当依赖,会导致显存持续增长,具体情况详见显存持续缓慢增长的究极原因 - 知乎 (zhihu.com)

  • 2
    点赞
  • 36
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值