一、混合精度训练
指同时使用单精度(FP32)和半精度(FP16)进行训练,有实验证明在保证模型效果不变的情况下,使用混合精度训练可以有效加快训练时间、减少网络训练时候所占用的内存。具体原理可以参见博文:全网最全-混合精度训练原理 ,这里就不做赘述。
总结而言混合精度训练有以下好处:
①FP16只占用通常使用的FP32一半的显存。
②N卡在对FP16计算速度比FP32快上许多
二、PyTorch中的混合精度训练
!!!使用此方法需要PyTorch版本>1.6。调用自带的torch.cuda.amp模块实现。
①检查是存在amp模块(版本>1.6)
3----fp16为混合精度开启的标志----#
if fp16:
from torch.cuda.amp import GradScaler as GradScaler
scaler = GradScaler()
else:
scaler = None
②导入amp模块
from torch.cuda.amp import autocast
③训练中进行数据转换
需要转换损失函数、迭代器,同时需要使用update函数
with autocast():
#----模型代码----#
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
三、存在的问题
由于前向传递过程中一些求和操作可能会导致数据溢出(上溢),所有可能出现nan的报错。则需要在模型中容易上溢的部分使用FP32算法。
上溢的部分采用debug的方式找出。
#----禁用amp自动切换----#
with torch.cuda.amp.autocast(enable=False):
#----将数据转换会fp32----#
value = value.to(torch.float32)