apex 安装、应用
按照以下方式安装apex
pip uninstall apex
git clone https://www.github.com/nvidia/apex
cd apex
python setup.py install
快速应用
from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # 这里是“欧一”,不是“零一”
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
opt_level需要用户自行配置:推荐用O1
O0:纯FP32训练,可以作为accuracy的baseline;
O1:混合精度训练(推荐使用),根据黑白名单自动决定使用FP16(GEMM, 卷积)还是FP32(Softmax)进行计算。
O2:“几乎FP16”混合精度训练,不存在黑白名单,除了Batch norm,几乎都是用FP16计算。
O3:纯FP16训练,很不稳定,但是可以作为speed的baseline;
check point
# Declare model and optimizer as usual, with default (FP32) precision
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# Initialization
opt_level = 'O1'
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
# Train your model
...
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
clip_grad_norm_(model.parameters(), 50) # 梯度裁剪
optimizer.step()
...
# Save checkpoint
checkpoint = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...
# Restore
model = ...
optimizer = ...
checkpoint = torch.load('amp_checkpoint.pt')
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])
# Continue training
...
注意,我们建议使用相同的opt_level恢复模型。还要注意,我们建议在amp.initialize之后调用load_state_dict方法
常见错误
由于apex报错并不明显,常常debug得让人很沮丧,但只要注意到以下的点,95%的情况都可以畅通无阻了:
- 判断你的GPU是否支持FP16:支持的有拥有Tensor Core的GPU(2080Ti、Titan、Tesla等),不支持的(Pascal系列)就不建议折腾了。
- 常数的范围:为了保证计算不溢出,首先要保证人为设定的常数(包括调用的源码中的)不溢出,如各种epsilon,INF(改成-float(‘inf’)就可以啦)等。
- Dimension最好是8的倍数:Nvidia官方的文档的2.2条表示,维度都是8的倍数的时候,性能最好。
- 涉及到sum的操作要小心,很容易溢出,类似Softmax的操作建议用官方API,并定义成layer写在模型初始化里。模型书写要规范:自定义的Layer写在模型初始化函数里,graph计算写在forward里**。
- 某些不常用的函数,在使用前需要注册:
amp.register_float_function(torch, 'sigmoid')
[某些函数(如einsum)暂不支持FP16加速,建议不要用的太heavy,xlnet的实现改FP16困扰了我很久。]大误,参考issue 802 slayton58的说法,注册好就可以强制加速了。 需要操作模型参数的模块(类似EMA),要使用AMP封装后的model。需要操作梯度的模块必须在optimizer的step里,不然AMP不能判断grad是否为Nan。 - 监督
参考:https://zhuanlan.zhihu.com/p/79887894
https://nvidia.github.io/apex/amp.html
https://www.github.com/nvidia/apex