关于Batch Normalization(批标准化)的理解与代码实现

今天看了《深度学习》中关于批标准化的小节,一开始感觉有些困惑,后来搜集了资料后也有了自己的理解,总结如下。


概念问题

我认为要理解批标准化首先要理解标准化概念。
那么,什么是标准化?

通过中心化或标准化处理,得到均值为0,标准差为1的服从标准正态分布的数据。

事实证明,一个神经网络接收一张白化(令像素点标准化)过后的图片作为输入数据,那么其收敛速度较快。那么将此实例延申,放入神经网络所有结构中,批标准化便诞生了。(这也解释了为什么MNIST数据集里面的像素点都是标准化后的值)

实例

由于批标准化是为了加速神经网络训练速度的,所以要先理解是什么导致了神经网络训练缓慢的。《深度学习》中给出了一个线性的例子:


假设我们有一个深度神经网络,每一层只有一个单元,并且在每个隐藏层不使用激励函数:
y ^ = x w 1 w 2 w 3 . . … . w l \hat{y}=xw_1w_2w_3..….w_l y^=xw1w2w3...wl

这里, w i w_i wi表示用于层 i i i 的权重。层 i i i 的输出是 h i = h i − 1 w i h_i=h_{i-1}w_i hi=hi1wi。输出g是输入x的线性函数,却是权重 w i w_i wi的非线性函数。假设我们的代价函数g上的梯度为1,所以我们希望稍稍降低。然后反向传播算法可以计算梯度 g = ∇ w y ^ g=\nabla_{w} \hat{y} g=wy^。想想我们在更新 w ← w − ϵ g w \leftarrow w-\epsilon g wwϵg时会发生什么。近似 y ^ \hat{y} y^的一阶泰勒级数会预测 g g g的值下降 ϵ g ⊤ g \epsilon g^{\top} g ϵgg。如果我们希望下降0.1,那么梯度中的一阶信息表明我们应设置学习速率 ϵ \epsilon ϵ 0.1 g ⊤ g \frac{0.1}{\boldsymbol{g}^{\top} \boldsymbol{g}} gg0.1。然而,实际的更新将包括二阶,三阶,直到 l l l 阶的影响。 y ^ \hat{y} y^ 的更新值为:
x ( w 1 − ϵ g 1 ) ( w 2 − ϵ g 2 ) … ( w l − ϵ g l ) x\left(w_{1}-\epsilon g_{1}\right)\left(w_{2}-\epsilon g_{2}\right) \ldots\left(w_{l}-\epsilon g_{l}\right) x(w1ϵg1)(w2ϵg2)(wlϵgl)


在以上这种较为极端的情况下,更新值也很可能会达到指数级大小。我们很难指定合适的学习率,因为某一层的更新效果很大程度上取决于其他所有层。(因为层 l l l 更新式子中包含了所有 l − 1 l-1 l1层级的更新值。如果学习率指定为 0.1 g ⊤ g \frac{0.1}{\boldsymbol{g}^{\top} \boldsymbol{g}} gg0.1,那么我们无法兼顾到所有层,有些层计算的结果大于1将导致最终结果呈指数级大小,有些层计算结果小于1可能被忽略)

其实该例子主要表达了由于深层神经网络在做非线性变换前的激活输入值(就是那个 y = W T X + b y=W^TX+b y=WTX+b,X是输入)随着网络深度加深或者在训练过程中,其分布逐渐发生偏移或者变动,之所以训练收敛慢,一般是整体分布逐渐往非线性函数的取值区间的上下限两端靠近(对于Sigmoid函数来说,意味着激活输入值 W T X + b W^TX+b WTX+b是大的负值或正值),所以这导致反向传播时低层神经网络的梯度消失,这是训练深层神经网络收敛越来越慢的本质原因,而BN就是通过一定的规范化手段,把每层神经网络任意神经元这个输入值的分布强行拉回到均值为0方差为1的标准正态分布,其实就是把越来越偏的分布强制拉回比较标准的分布,这样使得激活输入值落在非线性函数对输入比较敏感的区域,这样输入的小变化就会导致损失函数较大的变化,意思是这样让梯度变大,避免梯度消失问题产生,而且梯度变大意味着学习收敛速度快,能大大加快训练速度。

通过BN,梯度不会再简单的增加 h i h_i hi 的标准差或均值!每一层的输入数据都会位于激活函数中间区间,梯度不会消失。

上面的解释可能会过于抽象,下面结合书中的图例进行讲解。如图所示,这是代价函数的等高线图,我们可以发现该函数是病态的并呈现椭圆状,梯度下降策略并不能有效的下降到谷底,而是沿着陡峭的方向来回震荡。
在这里插入图片描述
而BN方法可以有效的解决该问题。通过批标准化,我们的代价函数图像将呈现单位圆状,梯度下降策略能够有效的下降到谷底。
并且解决该病态条件也可以通过调整每次前进的步伐方向来解决,比如带动量的SGD最速下降法(共轭梯度法)等。

回到上面的例子,我们通过标准化 h l − 1 h_{l-1} hl1可以有效的消除所有底层参数的影响—— h l − 1 h_{l-1} hl1始终都是具有0均值1方差的标准正态分布。这样我们的模型会变得很容易去学习。如果没有标准化,那么几乎每一个更新都会对 h l − 1 h_{l-1} hl1的统计量有着极端的影响。

但是很明显,看到这里,稍微了解神经网络的读者一般会提出一个疑问:如果都通过BN,那么不就跟把非线性函数替换成线性函数效果相同了?这意味着什么?我们知道,如果是多层的线性函数变换其实这个深层是没有意义的,因为多层线性网络跟一层线性网络是等价的。这意味着网络的表达能力下降了,这也意味着深度的意义就没有了。所以BN为了保证非线性的获得,对变换后的满足均值为0方差为1的x又进行了scale加上shift操作(y=scale*x+shift),每个神经元增加了两个参数scale和shift参数,这两个参数是通过训练学习到的,意思是通过scale和shift把这个值从标准正态分布左移或者右移一点并长胖一点或者变瘦一点,每个实例挪动的程度不一样,这样等价于非线性函数的值从正中心周围的线性区往非线性区动了动。

核心思想应该是想找到一个线性和非线性的较好平衡点,既能享受非线性的较强表达能力的好处,又避免太靠非线性区两头使得网络收敛速度太慢。

书中也指出BN操作其实类似于正则化:L2正则化其实是通过增加对权重参数的惩罚从而鼓励单元标准化激活统计量。例如书中的示例图(L2正则化效果),我们可以明显看出本来的病态条件被转换为了易于学习的状态。
在这里插入图片描述
更比如dropout的正则化防止过拟合,其实BN也具有类似效果。

实际应用

我们要知道BN是基于MINI-Batch的,即小批量梯度下降算法。由于《深度学习》中推荐的用法是在经历过激活函数之后使用BN操作,所以我们可以把BN操作想象成一个层级结构(类似于池化层一样的东东)。
对于每一个输入实例 x k x_k xk,我们进行如下变换:
x ^ ( k ) = x ( k ) − E [ x ( k ) ] Var ⁡ [ x ( k ) ] \hat{x}^{(k)}=\frac{x^{(k)}-E\left[x^{(k)}\right]}{\sqrt{\operatorname{Var}\left[x^{(k)}\right]}} x^(k)=Var[x(k)] x(k)E[x(k)]

继续添加:
y ( k ) = γ ( k ) x ^ ( k ) + β ( k ) y^{(k)}=\gamma^{(k)} \hat{x}^{(k)}+\beta^{(k)} y(k)=γ(k)x^(k)+β(k)

为了简化操作,我们使用tensorflow的相关方法来实现一下。tensorflow中关于BN的函数主要有两个,分别是:

  • tf.nn.moments(x, axes, name=None, keep_dims=False)
    该函数主要求输入x的均值与方差,返回两个tensor,分别为mean, variance。
  • tf.nn.batch_normalization()
    该函数接收参数不一定,此处,以layers中的封装为例,以MNIST作为数据集编写示例程序如下:
import tensorflow as tf
import os
from tensorflow.examples.tutorials.mnist import input_data

tf.logging.set_verbosity(tf.logging.INFO)

if __name__ == '__main__':
    mnist = input_data.read_data_sets('mnist', one_hot=True)
    x = tf.placeholder(tf.float32, [None, 784])
    y_ = tf.placeholder(tf.float32, [None, 10])
    image = tf.reshape(x, [-1, 28, 28, 1])
    conv1 = tf.layers.conv2d(image, filters=32, kernel_size=[3, 3], strides=[1, 1], padding='same',
                             activation=tf.nn.relu,
                             kernel_initializer=tf.truncated_normal_initializer(stddev=0.1),
                             name='conv1')
    bn1 = tf.layers.batch_normalization(conv1, training=True, name='bn1')
    pool1 = tf.layers.max_pooling2d(bn1, pool_size=[2, 2], strides=[2, 2], padding='same', name='pool1')
    conv2 = tf.layers.conv2d(pool1, filters=64, kernel_size=[3, 3], strides=[1, 1], padding='same',
                             activation=tf.nn.relu,
                             kernel_initializer=tf.truncated_normal_initializer(stddev=0.1),
                             name='conv2')
    bn2 = tf.layers.batch_normalization(conv2, training=True, name='bn2')
    pool2 = tf.layers.max_pooling2d(bn2, pool_size=[2, 2], strides=[2, 2], padding='same', name='pool2')

    flatten_layer = tf.contrib.layers.flatten(pool2, 'flatten_layer')
    weights = tf.get_variable(shape=[flatten_layer.shape[-1], 10], dtype=tf.float32,
                              initializer=tf.truncated_normal_initializer(stddev=0.1), name='fc_weights')
    biases = tf.get_variable(shape=[10], dtype=tf.float32,
                             initializer=tf.constant_initializer(0.0), name='fc_biases')
    logit_output = tf.nn.bias_add(tf.matmul(flatten_layer, weights), biases, name='logit_output')
    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logit_output))
    pred_label = tf.argmax(logit_output, 1)
    label = tf.argmax(y_, 1)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(pred_label, label), tf.float32))
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    global_step = tf.get_variable('global_step', [], dtype=tf.int32,
                                  initializer=tf.constant_initializer(0), trainable=False)
    learning_rate = tf.train.exponential_decay(learning_rate=0.1, global_step=global_step, decay_steps=5000,
                                               decay_rate=0.1, staircase=True)
    opt = tf.train.AdadeltaOptimizer(learning_rate=learning_rate, name='optimizer')
    with tf.control_dependencies(update_ops):
        grads = opt.compute_gradients(cross_entropy)
        train_op = opt.apply_gradients(grads, global_step=global_step)

    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    tf_config.allow_soft_placement = True
    sess = tf.InteractiveSession(config=tf_config)
    sess.run(tf.global_variables_initializer())

    # only save trainable and bn variables
    var_list = tf.trainable_variables()
    if global_step is not None:
        var_list.append(global_step)
    g_list = tf.global_variables()
    bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
    bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
    var_list += bn_moving_vars
    saver = tf.train.Saver(var_list=var_list,max_to_keep=5)
    # save all variables
    # saver = tf.train.Saver(max_to_keep=5)

    if tf.train.latest_checkpoint('ckpts') is not None:
        saver.restore(sess, tf.train.latest_checkpoint('ckpts'))
    train_loops = 10000
    for i in range(train_loops):
        batch_xs, batch_ys = mnist.train.next_batch(32)
        _, step, loss, acc = sess.run([train_op, global_step, cross_entropy, accuracy],
                                      feed_dict={x: batch_xs, y_: batch_ys})
        if step % 100 == 0:  # print training info
            log_str = 'step:%d \t loss:%.6f \t acc:%.6f' % (step, loss, acc)
            tf.logging.info(log_str)
        if step % 1000 == 0:  # save current model
            save_path = os.path.join('ckpts', 'mnist-model.ckpt')
            saver.save(sess, save_path, global_step=step)

    sess.close()

为了更加清晰的实现带BN结构的MNIST数据训练,下篇博客将以类的形式来实现神经网络结构。

  • 4
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

wangbowj123

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值