PyTorch必备神器 | 唯快不破:基于Apex的混合精度加速

640

作者丨Nicolas

单位丨追一科技AI Lab研究员

研究方向丨信息抽取、机器阅读理解

你想获得双倍训练速度的快感吗? 

你想让你的显卡内存瞬间翻倍吗? 

如果告诉你只需要三行代码即可实现,你信不? 

在这篇文章里,笔者会详解一下混合精度计算(Mixed Precision),并介绍一款 NVIDIA 开发的基于 PyTorch 的混合精度训练加速神器——Apex,最近 Apex 更新了 API,可以用短短三行代码就能实现不同程度的混合精度加速,训练时间直接缩小一半。 

话不多说,直接先教你怎么用。

PyTorch实现

from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # 这里是“欧一”,不是“零一”
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()

对,就是这么简单,如果你不愿意花时间深入了解,读到这基本就可以直接使用起来了。

但是如果你希望对 FP16 和 Apex 有更深入的了解,或是在使用中遇到了各种不明所以的“Nan”的同学,可以接着读下去,后面会有一些有趣的理论知识和笔者最近一个月使用 Apex 遇到的各种 bug,不过当你深入理解并解决掉这些 bug 后,你就可以彻底摆脱“慢吞吞”的 FP32 啦。

理论部分

为了充分理解混合精度的原理,以及 API 的使用,先补充一点基础的理论知识。

1. 什么是FP16?

半精度浮点数是一种计算机使用的二进制浮点数数据类型,使用 2 字节(16 位)存储。

640?wx_fmt=png

▲ FP16和FP32表示的范围和精度对比

 

其中, sign 位表示正负, exponent 位表示指数640?wx_fmt=png, fraction 位表示的是分数640?wx_fmt=png。其中当指数为零的时候,下图加号左边为 0,其他情况为 1。

640?wx_fmt=png

▲ FP16的表示范例

 

2. 为什么需要FP16?

在使用 FP16 之前,我想再赘述一下为什么我们使用 FP16。

  • 减少显存占用 现在模型越来越大,当你使用 Bert 这一类的预训练模型时,往往模型及模型计算就占去显存的大半,当想要使用更大的 Batch Size 的时候会显得捉襟见肘。由于 FP16 的内存占用只有 FP32 的一半,自然地就可以帮助训练过程节省一半的显存空间。

  • 加快训练和推断的计算 与普通的空间时间 Trade-off 的加速方法不同&#x

  • 25
    点赞
  • 64
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值