在AMP下训练时使用梯度剪裁的范例
for batch_idx, (inputs, targets) in progress_bar:
optimizer.zero_grad()
with autocast(enabled=use_amp):
outputs = model(inputs)
loss = criterion(outputs, targets)
running_train_loss += loss.item()
scaler.scale(loss).backward()
scaler.unscale_(optimizer) # 根据标准范例,在梯度裁剪之前进行unscale
grad_norm = torch.nn.utils.clip_grad_norm_(parameters=model.parameters(),
max_norm=2.0,
norm_type=2,
error_if_nonfinite=False,)
scaler.step(optimizer)
scaler.update()
在裁剪前需要进行unscale操作
scaler.unscale_(optimizer)
若不进行裁剪,则clip函数返回的grad_norm如下
正常的梯度曲线应当如下
大致原因
配合AMP的GradScaler的目的是缓解在低精度下出现的梯度消失或爆炸问题。其会对计算出来的loss进行缩放后再进行梯度计算,因此计算出来的梯度是对应于被缩放后的loss的。
若不使用unscale手动反缩放,则程序也将在调用scaler.step(optimizer)
时自动反缩放。
问题在于进行梯度裁剪时,设定的裁剪参数是基于正常情况的loss的,而不是缩放后的loss。所以需要手动提前进行unscale,而后再行裁剪,此时裁剪的对象就是正常loss了。
手动调用unscale后,scaler.step(optimizer)
将不再自动反缩放。若在程序中进行了两次反缩放,会抛出异常。
异常梯度曲线的阶跃点似乎是固定在某个epoch上的
batch_size对阶跃点的出现有影响。
我暂时没有很详细地了解过为什么阶跃点总是在一些特别的节点上,如:epoch400、epoch2000…