预训练模型输出nan的问题解决方法

问题描述

使用来自与训练集同一数据集的图片做推理测试时,预训练模型能正常提取图像特征,但是将其放到训练流程中时,模型输出为NAN
模型结构:预训练模型+下游任务层
使用的预训练模型:stability- AI的VAE

问题定位

模型训练途中出现NAN一般由三个原因导致:

  1. 输入数据存在非法值(输入数据中存在nan/inf等非法值)
  2. 模型参数存在nan值
  3. 模型自身原因,输入在模型中传输过程中某一层输出变为nan

输入数据原因排查

在该批次图片进入模型前,查看该批次图像的数据分布:

print("Input stats - mean: {}, std: {}, min: {}, max: {}".format(images.mean().item(), images.std().item(), images.min().item(), images.max().item()))

模型参数原因排查

 for name, param in self.model.named_parameters():
 	print(f"参数名称: {name}")
    print(f"参数值: {param.data}")
    if torch.isnan(param.data).any():
    	print(f"NaN detected after layer: {name}")

逐层排查预训练模型输出

写一个钩子函数:

def hook_function(module, output):  
    print(f"Module name: {module.__class__.__name__}")  
    print(f"Output shape: {output.shape}")  
    print("Output stats - mean: {}, std: {}, min: {}, max: {}".format(output.mean().item(), output.std().item(), output.min().item(), output.max().item()))
    if torch.isnan(output).any():
        print(f"NaN detected after layer: {module.__class__.__name__}")

钩子注册函数与移除函数:

def register_hooks(self):  
    # 遍历预训练模型的所有模块,为它们注册前向钩子  
    for name, module in self.model.named_modules():  
        if isinstance(module, nn.Module):  # 确保是模块,而非其他(如参数)  
            self.hooks.append(module.register_forward_hook(hook_function))  
  
def remove_hooks(self):  
    # 移除之前注册的所有钩子  
    for hook in self.hooks:  
        hook.remove()  

进预训练模型前注册钩子,结束后移除:

self.register_hooks()  
features = self.model(images)
self.remove_hooks()

经过排查后发现是预训练模型某层卷积之后出现了数值爆炸,导致了后层nan值的出现:
在这里插入图片描述

问题解决

由于我使用的是stability- AI的预训练VAE,所以去官方GitHub主页里去搜了搜有没有人遇到过类似的问题,结果确实在一条issue里看到:
在这里插入图片描述
由于模型训练时采用的BF16精度,所以如果使用fp16精度的话可能会造成无效数值/NAN/黑图的出现
所以重新配置我的模型的训练精度为BF16成功解决问题。

ps:
BFloat16 格式使用16位表示浮点数,其中1位用于符号,8位用于指数,7位用于尾数
Float16 格式使用16位表示浮点数,其中1位用于符号,5位用于指数,10位用于尾数
BFloat16的表示范围比Float16更广,但是精度更低
类似的问题还出现在训练大规模CLIP时,使用混合精度Float16训练会出现NAN,但是换成BFloat16就能解决该问题

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值