pytorch使用Apex混合精度加速训练

Apex官网:https://nvidia.github.io/apex/amp.html#

使用原因:

这篇博客讲的非常好
PyTorch必备神器 | 唯快不破:基于Apex的混合精度加速

1.安装

使用pip安装后会出错

TypeError: Class advice impossible in Python3. Use the @Implementer class decorator instead.

解决方法:

pip uninstall apex
git clone https://www.github.com/nvidia/apex
cd apex
python setup.py install

2.使用

核心代码:

from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # “欧一”,不是“零一”
with amp.scale_loss(loss, optimizer) as scaled_loss:
		scaled_loss.backward()

例子:

原始训练代码:

import torch
ngpu=2
def traiin():
		model = torch.nn.Linear(D_in, D_out).cuda()
		model = torch.nn.DataParallel(model, device_ids=[i for i in range(ngpu)])
		optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
		for img, label in dataloader:
			out = model(img.half())
			loss = LOSS(out, label)
			loss.backward()
			optimizer.step()
			optimizer.zero_grad()
#此时采用全精度32位来训练

半精度训练:

import torch
ngpu=2
def traiin():
		model = torch.nn.Linear(D_in, D_out).cuda().half()
		model = torch.nn.DataParallel(model, device_ids=[i for i in range(ngpu)])
		optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
		for img, label in dataloader:
			out = model(img.half())
			loss = LOSS(out, label)
			loss.backward()
			optimizer.step()
			optimizer.zero_grad()
#此时采用半精度16位来训练

显存基本可以降低为原来的一半,但训练速度降低,可能原因是,CUDNN只支持float32加速,半精度后,将不能加速。

混合精度训练:

import torch
ngpu=2
def train():
		model = torch.nn.Linear(D_in, D_out).cuda()
		optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
		#设置混合精度模式为O1(欧1,不是零1,后面会解释各个模式区别)
		model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
		model = torch.nn.DataParallel(model, device_ids=[i for i in range(ngpu)])
		for img, label in dataloader:
			out = model(img)
			loss = LOSS(out, label)
			#将loss进行缩放,防止溢出
			with amp.scale_loss(loss, optimizer) as scaled_loss:
		    	scaled_loss.backward()
		
			optimizer.step()
			optimizer.zero_grad()
def save_model(self, epoch):
        if self.mixed_precision:
            import apex.amp as amp
            amp_state_dict = amp.state_dict()
        else:
            amp_state_dict = None
        checkpoint = {
            'epoch': epoch,
            'params': self.params,
            'model': self.model.module.state_dict() if self.ngpu > 1 else self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'amp': amp_state_dict
        }
        torch.save(checkpoint, os.path.join(self.expdir,'model.pt'))

def load_model(self, checkpoint):
	  state_dict = torch.load(checkpoint)
	  self.model.load_state_dict(state_dict['model'])
	
	  if self.mixed_precision:
	      import apex.amp as amp
	      amp.load_state_dict(state_dict['amp'])

注意:
1.模型在amp.initialize前必须加载到GPU上。
2.amp.initialize前不能对模型进行任何分布式操作,如torch.nn.DataParallel必须放在之后。

opt_level解释
O0纯 FP32 训练,可以作为 accuracy 的 baseline
O1混合精度训练(推荐使用),根据黑白名单自动决定使用 FP16(GEMM, 卷积)还是 FP32(Softmax)进行计算
O2几乎FP16混合精度训练,不存在黑白名单,除了 Batch Norm,几乎FP16 计算
O3纯 FP16 训练,很不稳定,但是可以作为 speed 的 baseline

参考:
PyTorch必备神器 | 唯快不破:基于Apex的混合精度加速 [CSDN]
Apex [官网]
Apex混合精度加速 [码农网]

  • 3
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值