使用AMP时,进行梯度剪裁需要注意的细节

在AMP下训练时使用梯度剪裁的范例

for batch_idx, (inputs, targets) in progress_bar:
        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            running_train_loss += loss.item()

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # 根据标准范例,在梯度裁剪之前进行unscale
        grad_norm = torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), 
                                                    max_norm=2.0,
                                                    norm_type=2,
                                                    error_if_nonfinite=False,)
        scaler.step(optimizer)
        scaler.update()

在裁剪前需要进行unscale操作

scaler.unscale_(optimizer)

若不进行裁剪,则clip函数返回的grad_norm如下
具有阶跃特征
正常的梯度曲线应当如下
不具有阶跃特征

大致原因

配合AMP的GradScaler的目的是缓解在低精度下出现的梯度消失或爆炸问题。其会对计算出来的loss进行缩放后再进行梯度计算,因此计算出来的梯度是对应于被缩放后的loss的。
若不使用unscale手动反缩放,则程序也将在调用scaler.step(optimizer)时自动反缩放。

问题在于进行梯度裁剪时,设定的裁剪参数是基于正常情况的loss的,而不是缩放后的loss。所以需要手动提前进行unscale,而后再行裁剪,此时裁剪的对象就是正常loss了。

手动调用unscale后,scaler.step(optimizer)将不再自动反缩放。若在程序中进行了两次反缩放,会抛出异常。

异常梯度曲线的阶跃点似乎是固定在某个epoch上的

batch_size对阶跃点的出现有影响。
我暂时没有很详细地了解过为什么阶跃点总是在一些特别的节点上,如:epoch400、epoch2000…

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值