问题描述
使用来自与训练集同一数据集的图片做推理测试时,预训练模型能正常提取图像特征,但是将其放到训练流程中时,模型输出为NAN
模型结构:预训练模型+下游任务层
使用的预训练模型:stability- AI的VAE
问题定位
模型训练途中出现NAN一般由三个原因导致:
- 输入数据存在非法值(输入数据中存在nan/inf等非法值)
- 模型参数存在nan值
- 模型自身原因,输入在模型中传输过程中某一层输出变为nan
输入数据原因排查
在该批次图片进入模型前,查看该批次图像的数据分布:
print("Input stats - mean: {}, std: {}, min: {}, max: {}".format(images.mean().item(), images.std().item(), images.min().item(), images.max().item()))
模型参数原因排查
for name, param in self.model.named_parameters():
print(f"参数名称: {name}")
print(f"参数值: {param.data}")
if torch.isnan(param.data).any():
print(f"NaN detected after layer: {name}")
逐层排查预训练模型输出
写一个钩子函数:
def hook_function(module, output):
print(f"Module name: {module.__class__.__name__}")
print(f"Output shape: {output.shape}")
print("Output stats - mean: {}, std: {}, min: {}, max: {}".format(output.mean().item(), output.std().item(), output.min().item(), output.max().item()))
if torch.isnan(output).any():
print(f"NaN detected after layer: {module.__class__.__name__}")
钩子注册函数与移除函数:
def register_hooks(self):
# 遍历预训练模型的所有模块,为它们注册前向钩子
for name, module in self.model.named_modules():
if isinstance(module, nn.Module): # 确保是模块,而非其他(如参数)
self.hooks.append(module.register_forward_hook(hook_function))
def remove_hooks(self):
# 移除之前注册的所有钩子
for hook in self.hooks:
hook.remove()
进预训练模型前注册钩子,结束后移除:
self.register_hooks()
features = self.model(images)
self.remove_hooks()
经过排查后发现是预训练模型某层卷积之后出现了数值爆炸,导致了后层nan值的出现:
问题解决
由于我使用的是stability- AI的预训练VAE,所以去官方GitHub主页里去搜了搜有没有人遇到过类似的问题,结果确实在一条issue里看到:
由于模型训练时采用的BF16精度,所以如果使用fp16精度的话可能会造成无效数值/NAN/黑图的出现
所以重新配置我的模型的训练精度为BF16成功解决问题。
ps:
BFloat16 格式使用16位表示浮点数,其中1位用于符号,8位用于指数,7位用于尾数
Float16 格式使用16位表示浮点数,其中1位用于符号,5位用于指数,10位用于尾数
BFloat16的表示范围比Float16更广,但是精度更低
类似的问题还出现在训练大规模CLIP时,使用混合精度Float16训练会出现NAN,但是换成BFloat16就能解决该问题