梯度消失的有效解决方法-batch normalization

batch-normalization用在激活函数前的那一层,作用是调整该层均值和方差(一般是均值为0,方差为1),可以有效减少梯度消失问题。
使用tensorflow的实现过程:

#模拟一个tensor(CNN的tensor是四维,这里做了简化)
#同时初始化时方差设置的很小模拟梯度消失时中间层结点的情况
img=tf.Variable(tf.random_normal([4,4],stddev=0.00001),dtype=tf.float32)

#这里的axis代表在哪些维度上求均值和方差,对于CNN的四维tensor来说,一般是对前三维度加总求均值和方差
axis=list(range(len(img.get_shape())-1))

#调用tf.nn.moments函数求均值和方差,因为我们是对前三维求均值和方差,所以为mean和variance长度为第四维长度
mean,var=tf.nn.moments(img,axis)

#调用tf.nn.batch_normalization时相当于进行如下计算
#normalized = (img - mean) / tf.sqrt(var + epsilon)
#normalized = normalized * scale + shift
#所以offset相当对规范化后的分布进行调整
offset=tf.Variable(tf.zeros([mean.get_shape().as_list()[0]]))
scale=tf.Variable(tf.ones([mean.get_shape().as_list()[0]]))

#注意epsilon的值,epsilon的作用是防止分母为0
#但是如果epsilon比var大一个数量级的话,那var就没有用了,而会将整个分布扩大epsilon的倒数倍
#所以这里将epsilon设置为很小的值,为了方便观察结果(忽略epsilon的影响)
epsilon=0.00000000001
normalized=tf.nn.batch_normalization(img,mean,var,offset,scale,epsilon)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print('img')
    print(sess.run(img))
    print('normalized')
    print(sess.run(normalized))
    print('var')
    print(sess.run(var))
    print('mean')
    print(sess.run(mean))

实验结果如下所示:
在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值