Installation
Apex: a powerful tool for accelerating mixed-precision training
- Download the CUDA 10.0 build of PyTorch from https://download.pytorch.org/whl/cu100/torch_stable.html
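- Install PyTorch, then open a python3 prompt and run the following to check that the CUDA version PyTorch was built with matches the CUDA version installed on the machine:
import torch
torch.version.cuda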
- Clone the apex code from https://github.com/NVIDIA/apex and check out the following commit:
git checkout f3a960f80244cf9e80558ab30f7f7e8cbf03c0a0
- Follow the official documentation to build and install apex
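The version check in the steps above can also be scripted. The following is only a minimal sketch: the helper names (`parse_nvcc_release`, `cuda_versions_match`, `check_environment`) are made up for illustration, and it assumes that the machine's CUDA toolkit exposes `nvcc` on the PATH and that `nvcc --version` prints a line containing `release X.Y`.

```python
import re
import subprocess


def parse_nvcc_release(nvcc_output):
    """Extract the CUDA release (e.g. '10.0') from `nvcc --version` output, or None."""
    match = re.search(r"release\s+(\d+\.\d+)", nvcc_output)
    return match.group(1) if match else None


def cuda_versions_match(torch_cuda, system_cuda):
    """Compare only the major.minor parts; a missing version never matches."""
    if torch_cuda is None or system_cuda is None:
        return False
    return torch_cuda.split(".")[:2] == system_cuda.split(".")[:2]


def check_environment():
    """Hypothetical helper: returns (torch_cuda, system_cuda, versions_match)."""
    import torch  # imported lazily so the helpers above work without PyTorch

    try:
        result = subprocess.run(
            ["nvcc", "--version"],
            stdout=subprocess.PIPE,
            universal_newlines=True,
        )
        system_cuda = parse_nvcc_release(result.stdout)
    except FileNotFoundError:
        # nvcc is not on the PATH, so the system version is unknown
        system_cuda = None

    torch_cuda = torch.version.cuda  # None for CPU-only builds of PyTorch
    return torch_cuda, system_cuda, cuda_versions_match(torch_cuda, system_cuda)
```

Calling `check_environment()` on the target machine before building apex returns both version strings, so a mismatch can be caught early.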
Using apex
import torch
from apex import amp

# Initialization (model and optimizer are assumed to be defined already)
model = model.cuda()
opt_level = 'O1'
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
# Train your model
...
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
...
# Save checkpoint
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...
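# Restore (use the same opt_level as when saving, and call the
# load_state_dict methods only after amp.initialize)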
model = ...
optimizer = ...
checkpoint = torch.load('amp_checkpoint.pt')
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])
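
To show why the loss is wrapped in `amp.scale_loss` before `backward()`, here is a small standard-library illustration of loss scaling. It is not Apex's implementation: the gradient value and the scale factor are hypothetical, and `to_fp16` is a helper written only for this sketch. The idea is that a very small value underflows to zero in half precision, whereas the same value multiplied by a scale factor survives and can be divided back afterwards in higher precision.

```python
import struct


def to_fp16(value):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", value))[0]


grad = 1e-8      # hypothetical gradient, too small for fp16
scale = 1024.0   # hypothetical loss scale (a power of two)

unscaled = to_fp16(grad)         # underflows to 0.0 in fp16
scaled = to_fp16(grad * scale)   # still representable in fp16
recovered = scaled / scale       # unscale in higher precision

print("fp16 without scaling:", unscaled)
print("fp16 with scaling, then unscaled:", recovered)
```

In actual training the scale factor does not have to be chosen by hand; the snippet above simply passes `opt_level='O1'` and leaves loss scaling to amp.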