在Linux系统下安装Apex库
1. 安装流程(按顺序使用如下命令)
git clone https://github.com/NVIDIA/apex
cd apex
pip3 install -v --no-cache-dir ./
注意:不能直接使用pip3 install apex来安装, 因为如果直接用pip3 install apex, 会导致在导入时候报错.
卸载
pip uninstall apex
2. 使用方法
只需要在原模型训练的代码中修改三处
(1)添加 from apex import amp;
(2)在定义完model和optimizer后,添加 model, optimizer = amp.initialize(model, optimizer, opt_level="O1");注意是字母O
(3)在模型训练部分代码中,注释掉 loss.backward(),使用如下代码来替换:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
3. Reference
1. https://nvidia.github.io/apex/index.html