🎏目录
🎈1
🎈2
🎄2.1
🎄2.2
🎄2.3
🎄2.4
✨1 混合精度训练简介
目前,Pytorch一共支持10种数据类型:
- torch.FloatTensor # 另一种表述:FP32**
- torch.DoubleTensor # 64-bit floating point
- torch.HalfTensor # 另一种表述:FP16
- torch.ByteTensor
- torch.CharTensor
- torch.ShortTensor
- torch.IntTensor
- torch.LongTensor
默认使用的是32位浮点型精度的Tensor,即torch.FloatTensor
。因此,默认情况下我们训练的是一个FP32的模型。但不是所有数据都需要FP32那么大的内存。
此时,采用自动混合精度(Automatic Mixed Precision, AMP)训练,一部分算子数值精度为FP16,其余算子的数值精度是FP32,而哪些算子用FP16,哪些用FP32,由amp自动安排。这样在不改变模型、不降低模型训练精度的前提下,可以缩短训练时间,降低存储需求,因而能支持更多的 batch size、更大模型和尺寸更大的输入进行训练。
✨2 自动混合精度训练的使用
结合一段示例代码来看:
# amp依赖Tensor core架构,所以模型必须在cuda设备下使用
model = Model()
model.to("cuda") # 必须!!!
optimizer = optim.SGD(model.parameters(), ...)
# (新增)创建GradScaler对象
scaler = GradScaler(enabled=True) # 虽然默认为True,体验一下过程
for epoch in epochs:
for img, target in data:
optimizer.zero_grad()
# (新增)启动autocast上下文管理器
with autocast(enabled=True):
# (不变)上下文管理器下,model前向传播,以及loss计算自动切换数值精度
output = model(img)
loss = loss_fn(output, target)
# (修改)反向传播
scaler.scale(loss).backward()
# (修改)梯度计算
scaler.step(optimizer)
# (新增)scaler更新
scaler.update()
使用自动化精度时,只有在模型以及损失计算,反向传播,梯度更新时作出一定的改变,具体有:
scaler = GradScaler()
:创建对象GradScaler,并赋予变量scalerwith autocast():
:启动autocast
上下文管理器,内含需要做精度放缩的计算(必须包含模型计算以及损失计算)scaler.scale(loss).backward()
:利用scaler做反向传播scaler.step(optimizer)
:梯度更新scaler.update()
:scaler更新
🎃 2.2 GradScaler
构造:
torch.cuda.amp.GradScaler(
init_scale=65536.0,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
enabled=True,
)
这里形式很固定,只有一个参数enabled根据自己的需求进行改变
参数:
- enabled:是否做scale。如果为False,则返回原数据。如果为True,则进行一次精度转换。
原谅我,涉及到了原理,水平有限,真的看不懂,欢迎大家交流
🎉 2.2 autocast
autocast(
enable=True # 同上
)
上面的示例中,autocast是在训练脚本中使用的,除此之外还有两种方式:
- 作为装饰器,在forward函数中使用
==============================================================
class Model(nn.Module):
def __init__(self):
pass
@torch.cuda.amp.autocast() # autocast导入路径
def forward():
pass
==============================================================
- 在forward中使用上下文管理器
==============================================================
class Model(nn.Module):
def __init__(self):
pass
def forward():
with torch.cuda.amp.autocast(): # 上下文管理器
pass
==============================================================