工作笔记|基于Apex的混合精度加速

from apex import amp

model, optimizer = amp.initialize(model, optimizer, opt_level='O1') # not 01

with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()

optimizer = torch.optim.Adam([{'params': model.backbone.parameters(), 'lr': 3e-5},
{'params': model.fc.parameters(), 'lr': 3e-4},   ])

混合精度计算 (Mixed Precision)
NVIDIA开发的基于PyTorch的混合精度训练加速器Apex
使用上述代码可以实现不同程度的混合精度加速,训练时间缩短一半。

展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 1024 设计师: 上身试试
应支付0元
点击重新获取
扫码支付

支付成功即可阅读