当BatchNormalize遇到异常值,导致预测结果都是NaN

问题背景

最近使用DNN模型来做排序,发现近几天的NDCG离线指标下跌得很厉害。于是下载模型自己在本地评测了一下,预测结果都是NaN,于是把各层的模型参数以及各层的输出都打印出来,发现BatchNormalize中的moving_variance(方差)的某一维是NaN,最后一查果然是这一维特征异常了。为了把事情弄清楚,写这个blog记录一下。

BatchNormalize(BN)基础知识

BN的提出是为了解决神经网络中Internal Covariate Shift的问题,Internal Covariate Shift简单地说就是各层网络的输出会产生分布的变化,而分布变化使神经网络比较难收敛。BN的思想也很简单,就是通过把输入各个特征的分布转化成均值为0,方差为1的正态分布上去。训练过程是在mini-batch中进行操作的,具体也就是求出均值、标准差,通过下边式子对输入特征进行转化。
在这里插入图片描述为了保留原来的特征分布信息,加入了可学习的参数:

在这里插入图片描述
可以看出,当伽马等于标准差,贝塔等于均值时就完全还原了输入。

但是在模型预测时mini-batch可能只有一个实例,所以在实现中的做法是保留每一个batch计算的均值和方差,通过平滑的方式计算预测使用的均值和方差:

问题复现

使用了如下代码进行了问题的复现:

import tensorflow as tf
import numpy as np

bn_input = tf.cast([[1, 1, 3],
                    [2, 5, 6]], tf.float32)
bn_layer = tf.compat.v1.layers.batch_normalization(bn_input, training=True, momentum=0)

with tf.compat.v1.Session() as sess:
    update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
    variables = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope='batch_normalization')
    print('UPDATE_OPS ----------------------->\n', update_ops)
    print('GLOBAL_VARIABLES ----------------------->\n', variables)
    moving_mean = variables[2]
    moving_variance = variables[3]
    init = tf.compat.v1.global_variables_initializer()
    r1 = sess.run(init)
    print(r1)
    r2 = sess.run([update_ops, bn_layer])
    print(r2)
    result= sess.run([moving_mean, moving_variance])
    print(result)

结果如下:
在这里插入图片描述
我们看到均值、方差以及输出都是符合预期的。
当在input插入一个很大的值时,我们可以看到variance是nan。可以理解,输入的差异非常大时,方差会非常大,在BN计算时就会出现nan的现象。

bn_input = tf.cast([[1, 1, 3],
                    [2e44, 5, 6]], tf.float32)

在这里插入图片描述

复盘总结

在查问题的时候,开始就规规矩矩的看样本分布是否有变化,看特征是否有变化。其实使用BN的过程中就是可以很好的debug模型,监控特征分布的变化情况。更重要地,是系统性的解决这样的鲁棒性问题,比如对特征进行分桶,如果一定要用连续特征时需要对特征变化进行监控。
好了,留下一些思考题:从模型的角度看BN层似乎不能很好的应对异常值,那什么样的模型结构能够优雅地处理异常值呢?GBDT?如果直接使用全连接,结果会怎么样呢?

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值