BatchNorm层测试出现nan的查bug之旅

                                                              BatchNorm层测试出现nan的查bug之旅

1、前言

在caffe中使用BatchNorm层很简单,只要注意一点,在训练时将use_global_states设为false,测试前向阶段将use_global_states设为true即可。在tensorflow中,其实也不难(具体可参考我的另一外博客:https://blog.csdn.net/LxDamon/article/details/108762087),稍微复杂一点点,只要注意batchnorm的mean和variance不是trainable的变量,只是会计算更新,简单理解就是它们不参与反向传播,只有BN层的尺度因子gamma和偏移因子beta是需要学习即进行反向传播的,这里就不展开讲BN的原理(后面会单拎一篇博客仔细讲解与分析BN的原理)。在tensorflow中首先需要利用tf.GraphKeys.UPDATE_OPS将需要更新的mean和variance变量收集到update_ops列表中,然后训练时设置trainable为True,测试推理时将trainable设为False,一般做好这两点,基本BN就可以正常使用,但是接下来的bug却困扰到我了。

2、bug的现象

训练过程中利用tensorboard显示网络输出结果图是正常的,但是训练完成后,前向推理时却发现网络输出结果是全黑的:

3、解决bug之旅

出现上述含有BN层测试结果出错全黑的bug,首先对网络每一层的输出进行打印,定位到了出错的地方,发现是conv10输出结果为nan,该层的结构是input->conv->BN->leaky_relu。我一开始就怀疑是BN的原因,所以直接对BN的输入与输出进行打印,从数值上看,BN的输入看上去是正常的,但是它的输出是nan:

BN的输入:

BN的输出:

这样就定位到了的确是BN层出问题了,开始我还很有信心,觉得应该要么就是trainable设置不对,要么就是update_ops不对,但是最后的结果都出乎我的意料之外。

首先检查的就是tainable是否设置正确,发现tainable设置是对的,训练时设置的是True,测试时设置是False。然后去检查update_ops列表中是否含有BN层需要更新的变量即滑动平均moving_mean和滑动方差moving_variance,发现列表中确实存在需要更新的变量:

这下我就纳闷了,发现tensorflow中使用BN经常出错的两个坑我都没有踩,但是还是出现了很诡异的错误。然后我就怀疑是不是保存模型的时候mean和std没有保存进去,于是我就去打印ckpt模型,打印模型的代码如下:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from tensorflow.python import pywrap_tensorflow

def main():
    ckpt_path='./ckpt/model.ckpt-300000'
    reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
        str_s = str(key)
        if str_s.find('moving_mean') > 0 or str_s.find('moving_variance') > 0:
            print("tensor_name: ", key)
            print(reader.get_tensor(key))

if __name__ == '__main__':
    main()

打印ckpt模型并且利用BN的moving_mean和moving_variance关键字提取信息,发现ckpt模型中是存在均值和方差的,说明保存模型的代码是没有问题的:

但是我们从上图可以看出,虽然模型中存在滑动均值与滑动方差,但是moving_mean的值特别小,moving_variance的值直接为nan,这下就说明问题了,说明在训练过程中均值与方差的更新就是不对的,可能有些小伙伴会纳闷了,前面不是说过在训练过程中利用tensorboard看网络输出是好的嘛,是的,注意一点,在训练过程中BN所使用的是每个batch里面在线计算的均值和方差,并没有用到滑动平均的均值和方差。为了验证我这个想法,首先很简单,我在测试代码里直接将BN的trainlabel改成True,结果就正常了,就不是全黑图了。

为了进一步验证我的想法和找到bug的本质以及解决它,我在训练过程中将BN的输入、BN的moving_mean和moving_variance也打印出来,代码片段为:

BN_mean,BN_variance = sess.run(['generator/batch_normalization_10/moving_mean: 0',
	'generator/batch_normalization_10/moving_variance: 0'])
print('****************step****************', i)
print('****************BN_mean_10****************',BN_mean)
print('****************BN_variance_10****************', BN_variance)

训练中打印信息为:

由于我已经验证过训练过程中的输入是没有问题,从训练中网络的输出结果也可以判断每一层的输入是肯定没有问题的。但从上图中发现mean的非常小差不多都在10的负8次方,variance的值为nan,说明了问题还得出现在训练过程中滑动均值与方差的计算和更新出现问题了。开始我怀疑是不是动量参数momentum比较大,所以导致了这个问题,于是我将momentum从0.99减少为0.9,打印信息为:

从上图7可以看出,将动量参数该小之后,moving_mean的值果然变的大了一点,从10的负8次方增大到了10的负6次方,感觉有戏,于是为了进一步验证是不是该参数导致的问题,直接来个极端的,将momentum设为0.1,打印信息为:

发现moving_mean的值没有再继续增大了,并且moving_variance的值仍为nan。所以这条路又不通了,可以判定不是这个参数的问题,于是我想是不是我使用BN层时有些参数没有设对,于是我把tf.layers.batch_normalization函数的所有传参研究个透,并且多次尝试参数的设置,均没有解决上述bug。这下没辙了,感觉好奇怪,啥问题也没有,为啥会出现nan了,没办法,去看看tf.layers.batch_normalization函数的源码,想着在源码里可能发现一些蛛丝马迹,果然发现了问题,batch_size的大小必须大于1,不能在tensorflow中更新mean和variance走的就不是正常的流程,我回过头查看了一下我的batch_size为1,于是我将batch_size改为2,训练中打印BN层的mean和variance如下:

发现BN层的moving_mean和moving_variance值终于正常了,最后测试时网络输出结果也正常了,所以使用BN层时要特别注意规避上述那些坑。

有问题欢迎评论交流,一起进步!

  • 5
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 6
    评论
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值