使用`amp`进行GPU运算优化的学习笔记(会出现nan)

1 前言:不推荐,容易出现nan

经过一段时间的探索,我们认为:amp训练是不推荐的,容易出现nan
在debug过程中,我们发现在网络的前向运算中出现了nan,这一点已经在PyTorchForums提交了issue_nan_amp,目前还没有得到maintainer的回复;
该问题具体的表现:在前向运算中出现nan,但是1. 当前算子的参数不包含nan;2.并且输入数据不含有nan,目前无法判断出现nan的原因;
测试代码如下:

...
# forward pass
with torch.cuda.amp.autocast(self.use_amp):
    x = self.backbone(img)
    x = self.neck(x)

    if self.qa.check_nan:
    	# 检查输入x是否含有nan
        assert not x.isnan().any()
        # 检查dict_feats["layer2"]是否含有nan
        assert not self.dict_feats["layer2"].isnan().any()
        for p in self.decoder2.parameters():
            assert not p.isnan().any()

    if isinstance(self.decoder2, FuseDecoder):
        encode_data = self.decoder2(x, self.dict_feats["layer2"])
    else:
        raise NotImplementedError

if self.qa.check_nan:
    assert not encode_data.isnan().any()
...

经过实验,出现以下的提示信息:

in train_epochs
output = self.model(batch_img)
File “/home/user/software/python/anaconda/anaconda3/envs/conda-general/lib/python3.10/site-packages/torch/nn/modules/module.py”, line 1110, in _call_impl
return forward_call(*input, **kwargs)
File “/home/…net.py”, line 331, in forward
assert not encode_data.isnan().any()
AssertionError

可以看到,encode_data中出现nan,说明数值计算过程不是很稳定;

调试记录

“x = self.bn(x)”: BatchNorm2d

提示信息:
在这里插入图片描述
可以看到是在146行出现了错误,
代码截图:
在这里插入图片描述
是在BN层出现了错误,这里感觉BN层容易出现数值溢出(目前暂时没有什么比较好的解决方案);

2 容易出现nan的算子:BatchNorm2d

这里我们来记录一下容易出现nan的算子:nn.BatchNorm2d

3 AMP训练:torch.cuda.amp

示例代码:

# amp依赖Tensor core架构,所以model参数必须是cuda tensor类型
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# GradScaler对象用来自动做梯度缩放
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        # 在autocast enable 区域运行forward
        with autocast():
            # model做一个FP16的副本,forward
            output = model(input)
            loss = loss_fn(output, target)
        # 用scaler,scale loss(FP16),backward得到scaled的梯度(FP16)
        scaler.scale(loss).backward()
        # scaler 更新参数,会先自动unscale梯度
        # 如果有nan或inf,自动跳过
        scaler.step(optimizer)
        # scaler factor更新
        scaler.update()

4 使用装饰器指定不使用amp的模块:keep_forward_float_()

def keep_forward_float_(m):
    def float_forward(self, x, forward):
        assert isinstance(self, nn.Module)
        with autocast(enabled=False):
            return forward(x.float())  
            # x.float()指将输入数据转换为fp32类型

    m.forward = MethodType(functools.partial(float_forward, forward=m.forward), m)

5 使用amp和GradientAccumulation联合进行优化

scaler = GradScaler()

for epoch in epochs:
    for i, (input, target) in enumerate(data):
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)
            loss = loss / iters_to_accumulate # 看看这个是否可以省略

        # Accumulates scaled gradients.
        scaler.scale(loss).backward()

        if (i + 1) % iters_to_accumulate == 0:
            # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

3 Troubleshooting

3.1 RuntimeError: torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.

运行时出现错误提示:

RuntimeError: torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.
Many models use a sigmoid layer right before the binary cross entropy layer.
In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits or torch.nn.BCEWithLogitsLoss. binary_cross_entropy_with_logits and BCEWithLogits are safe to autocast.

由提示信息可知,torch规定无法在autocast作用域中使用nn.BCELoss(reduction="none"),于是,需要在代码中单独声明在计算BCE损失时不使用autocast,示例代码如下:

with autocast(enabled=False):
    bce = self.bce_loss(output_map.float(), target_map.float())

Note:
autocast(enabled=False)作用域中引用的tensor需要使用其float()版本;请参考torch官方示例amp_force_float32

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值