问题背景
最近使用DNN模型来做排序,发现近几天的NDCG离线指标下跌得很厉害。于是下载模型自己在本地评测了一下,预测结果都是NaN,于是把各层的模型参数以及各层的输出都打印出来,发现BatchNormalize中的moving_variance(方差)的某一维是NaN,最后一查果然是这一维特征异常了。为了把事情弄清楚,写这个blog记录一下。
BatchNormalize(BN)基础知识
BN的提出是为了解决神经网络中Internal Covariate Shift的问题,Internal Covariate Shift简单地说就是各层网络的输出会产生分布的变化,而分布变化使神经网络比较难收敛。BN的思想也很简单,就是通过把输入各个特征的分布转化成均值为0,方差为1的正态分布上去。训练过程是在mini-batch中进行操作的,具体也就是求出均值、标准差,通过下边式子对输入特征进行转化。
为了保留原来的特征分布信息,加入了可学习的参数:
可以看出,当伽马等于标准差,贝塔等于均值时就完全还原了输入。
但是在模型预测时mini-batch可能只有一个实例,所以在实现中的做法是保留每一个batch计算的均值和方差,通过平滑的方式计算预测使用的均值和方差:
问题复现
使用了如下代码进行了问题的复现:
import tensorflow as tf
import numpy as np
bn_input = tf.cast([[1, 1, 3],
[2, 5, 6]], tf.float32)
bn_layer = tf.compat.v1.layers.batch_normalization(bn_input, training=True, momentum=0)
with tf.compat.v1.Session() as sess:
update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
variables = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope='batch_normalization')
print('UPDATE_OPS ----------------------->\n', update_ops)
print('GLOBAL_VARIABLES ----------------------->\n', variables)
moving_mean = variables[2]
moving_variance = variables[3]
init = tf.compat.v1.global_variables_initializer()
r1 = sess.run(init)
print(r1)
r2 = sess.run([update_ops, bn_layer])
print(r2)
result= sess.run([moving_mean, moving_variance])
print(result)
结果如下:
我们看到均值、方差以及输出都是符合预期的。
当在input插入一个很大的值时,我们可以看到variance是nan。可以理解,输入的差异非常大时,方差会非常大,在BN计算时就会出现nan的现象。
bn_input = tf.cast([[1, 1, 3],
[2e44, 5, 6]], tf.float32)
复盘总结
在查问题的时候,开始就规规矩矩的看样本分布是否有变化,看特征是否有变化。其实使用BN的过程中就是可以很好的debug模型,监控特征分布的变化情况。更重要地,是系统性的解决这样的鲁棒性问题,比如对特征进行分桶,如果一定要用连续特征时需要对特征变化进行监控。
好了,留下一些思考题:从模型的角度看BN层似乎不能很好的应对异常值,那什么样的模型结构能够优雅地处理异常值呢?GBDT?如果直接使用全连接,结果会怎么样呢?