1 Preface: not recommended, prone to NaN
After a period of experimentation, our conclusion is: AMP training is not recommended, because it is prone to producing NaN.
During debugging we found that NaN appeared in the network's forward pass. We have filed an issue about this on the PyTorch Forums (issue_nan_amp), but so far no maintainer has replied.
The concrete symptom: NaN appears during the forward pass, even though (1) the parameters of the current operator contain no NaN and (2) the input data contains no NaN. We have not yet been able to determine where the NaN comes from.
The test code is as follows:
...
# forward pass
with torch.cuda.amp.autocast(self.use_amp):
    x = self.backbone(img)
    x = self.neck(x)
    if self.qa.check_nan:
        # check whether the input x contains NaN
        assert not x.isnan().any()
        # check whether dict_feats["layer2"] contains NaN
        assert not self.dict_feats["layer2"].isnan().any()
        # check whether any decoder parameter contains NaN
        for p in self.decoder2.parameters():
            assert not p.isnan().any()
    if isinstance(self.decoder2, FuseDecoder):
        encode_data = self.decoder2(x, self.dict_feats["layer2"])
    else:
        raise NotImplementedError
    if self.qa.check_nan:
        assert not encode_data.isnan().any()
...
Running the experiment produces the following error message:
in train_epochs
    output = self.model(batch_img)
  File "/home/user/software/python/anaconda/anaconda3/envs/conda-general/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/…net.py", line 331, in forward
    assert not encode_data.isnan().any()
AssertionError
As we can see, NaN appears in encode_data, which indicates that the numerical computation is not very stable.
Debugging notes
"x = self.bn(x)": BatchNorm2d
Error message (screenshot not reproduced here): the error is raised at line 146, inside the BN layer.
The failure therefore occurs in the BatchNorm layer; the BN layer seems prone to numerical overflow under FP16 (for now we have no good solution for this).
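For reference, here is a minimal sketch (not part of the original debugging session) of how forward hooks can be used to locate which submodule first produces a NaN output under autocast; model and sample_input are placeholder names:
import torch
import torch.nn as nn

def find_nan_modules(model: nn.Module, sample_input: torch.Tensor):
    """Print every submodule whose output contains NaN; hooks fire in execution order,
    so the first module printed is where NaN first appears in the forward pass."""
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and output.isnan().any():
                print(f"NaN in output of: {name} ({module.__class__.__name__})")
        return hook

    for name, module in model.named_modules():
        handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.cuda.amp.autocast():
            model(sample_input)
    finally:
        for h in handles:
            h.remove()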
2 Operators prone to NaN: BatchNorm2d
Here we record an operator that is prone to producing NaN: nn.BatchNorm2d.
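As a small diagnostic in the same spirit as the self.qa.check_nan asserts above (a sketch, not from the original post), one can scan a model for BatchNorm2d layers and verify that their parameters and running statistics are still finite:
import torch
import torch.nn as nn

def check_batchnorm_health(model: nn.Module):
    """Assert that no BatchNorm2d layer carries NaN/Inf in its parameters or running stats."""
    for name, m in model.named_modules():
        if isinstance(m, nn.BatchNorm2d):
            for tag, t in [("weight", m.weight), ("bias", m.bias),
                           ("running_mean", m.running_mean), ("running_var", m.running_var)]:
                if t is not None:
                    assert torch.isfinite(t).all(), f"{name}.{tag} contains NaN/Inf"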
3 AMP training: torch.cuda.amp
Example code:
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler

# AMP relies on the Tensor Core architecture, so the model parameters must be CUDA tensors
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

# a GradScaler object handles gradient scaling automatically
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        # run the forward pass inside the autocast-enabled region
        with autocast():
            # under autocast, eligible ops run the forward pass in FP16
            output = model(input)
            loss = loss_fn(output, target)
        # scale the (FP16) loss and backpropagate to obtain scaled gradients
        scaler.scale(loss).backward()
        # scaler.step() first unscales the gradients automatically, then updates the parameters;
        # if the gradients contain NaN or Inf, the step is skipped automatically
        scaler.step(optimizer)
        # update the scale factor
        scaler.update()
4 Using a decorator to mark modules that should not use AMP: keep_forward_float_()
import functools
from types import MethodType

import torch.nn as nn
from torch.cuda.amp import autocast

def keep_forward_float_(m):
    def float_forward(self, x, forward):
        assert isinstance(self, nn.Module)
        with autocast(enabled=False):
            # x.float() casts the input data to FP32 before calling the original forward
            return forward(x.float())

    m.forward = MethodType(functools.partial(float_forward, forward=m.forward), m)
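A possible usage sketch (the model and module names are placeholders, not from the original post): apply keep_forward_float_ to every module that should stay in FP32, for example all BatchNorm2d layers:
model = Net().cuda()
# keep every BatchNorm2d layer in FP32 even inside an autocast region
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        keep_forward_float_(m)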
5 Combining AMP with gradient accumulation
scaler = GradScaler()

for epoch in epochs:
    for i, (input, target) in enumerate(data):
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)
            # average the loss over the accumulated iterations
            # (TODO from the original notes: check whether this can be omitted)
            loss = loss / iters_to_accumulate
        # accumulate scaled gradients
        scaler.scale(loss).backward()
        if (i + 1) % iters_to_accumulate == 0:
            # may call unscale_ here if desired (e.g., to allow clipping of unscaled gradients)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
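As the comment above suggests, the gradients can be unscaled before the optimizer step, for example to clip them at their true magnitude. A minimal sketch based on the documented GradScaler.unscale_ API (the max_norm value is an arbitrary example):
if (i + 1) % iters_to_accumulate == 0:
    # unscale the accumulated gradients in place so clipping sees their true values
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    # step() will not unscale again, since unscale_ was already called for this optimizer
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()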
6 Troubleshooting
6.1 RuntimeError: torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.
The following error appears at runtime:
RuntimeError: torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.
Many models use a sigmoid layer right before the binary cross entropy layer.
In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits or torch.nn.BCEWithLogitsLoss. binary_cross_entropy_with_logits and BCEWithLogits are safe to autocast.
According to the error message, torch does not allow nn.BCELoss(reduction="none") to be used inside an autocast region. Therefore, we need to explicitly disable autocast when computing the BCE loss, for example:
with autocast(enabled=False):
    bce = self.bce_loss(output_map.float(), target_map.float())
Note: tensors referenced inside an autocast(enabled=False) region should be cast to their float() version; see the official torch example amp_force_float32.
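Alternatively, as the error message itself recommends, the sigmoid and the BCE loss can be fused with torch.nn.BCEWithLogitsLoss (or torch.nn.functional.binary_cross_entropy_with_logits), which is safe to autocast. A minimal sketch with placeholder names (model, input, target):
import torch.nn as nn
from torch.cuda.amp import autocast

# BCEWithLogitsLoss takes raw logits, so the model should not end with a sigmoid layer
criterion = nn.BCEWithLogitsLoss()

with autocast():
    logits = model(input)
    loss = criterion(logits, target.float())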