tensorflow-利用batch_normalization进行标准化

batch_normalization:批标准化,即对整个神经训练的批次进行标准化处理,使其取值位于激励函数的非饱和范围,批标准化不仅只对于输入层,对于所有的隐藏层(需要激励函数)均进行处理,有点是加速训练效率,防止细胞死亡。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

tf.set_random_seed(1)
np.random.seed(1)

# 训练参数
N_SAMPLES = 2000
BATCH_SIZE = 64
EPOCH = 12
LR = 0.03
N_HIDDEN = 8
ACTIVATION = tf.nn.tanh
B_INIT = tf.constant_initializer(-0.2)      # use a bad bias initialization

# 训练数据 y = x**2-5+noise
x = np.linspace(-7, 10, N_SAMPLES)[:, np.newaxis]
np.random.shuffle(x)
noise = np.random.normal(0, 2, x.shape)
y = np.square(x) - 5 + noise
train_data = np.hstack((x, y))

# 测试数据
test_x = np.linspace(-7, 10, 200)[:, np.newaxis]
noise = np.random.normal(0, 2, test_x.shape)
test_y = np.square(test_x) - 5 + noise

# 数据可视化
plt.scatter(x, y, c='#FF9359', s=50, alpha=0.5, label='train')
plt.legend(loc='upper left')

# 生成variable和tensor
tf_x = tf.placeholder(tf.float32, [None, 1])
tf_y = tf.placeholder(tf.float32, [None, 1])
tf_is_train = tf.placeholder(tf.bool, None)     # 标识:识别是否需要使用批标准化


class NN(object):
    def __init__(self, batch_normalization=False):
        self.is_bn = batch_normalization

        self.w_init = tf.random_normal_initializer(0., .1)  
        self.pre_activation = [tf_x]
        if self.is_bn:
#-----------addition:这里采用了最新的api:tf.layers.batch_normalization<===>原用tf.train.moment()求平均值和方差
#------------------------------------------------------------------------x_re = (x-mean_x)/sqrt(var_x**2+sigma**2  ----标准化
#------------------------------------------------------------------------x_nor = W*x_re+b  -------最后反标准化,用于神经网络训练
            self.layer_input = [tf.layers.batch_normalization(tf_x, training=tf_is_train)]  # for input data
        else:
            self.layer_input = [tf_x]
#-------增加隐藏层
        for i in range(N_HIDDEN): 
            self.layer_input.append(self.add_layer(self.layer_input[-1], 10, ac=ACTIVATION))
#-------定义输出层和损失函数
        self.out = tf.layers.dense(self.layer_input[-1], 1, kernel_initializer=self.w_init, bias_initializer=B_INIT)
        self.loss = tf.losses.mean_squared_error(tf_y, self.out)

#update_ops存放了计算图中,需要更新的变量集合,tf.GraphKeys.UPDATE_OPS存放着需要更新的变量名称,该变量创建是由tf.layers.batch_normalization创建:
#训练时,需要更新moving_mean和moving_variance。默认情况下,更新操作被放入tf.GraphKeys.UPDATE_OPS,因此需要将它们作为依赖项添加到train_op。
#此外,在获取update_ops集合之前,请务必添加任何batch_normalization操作。否则,update_ops将为空,并且训练/推断将无法正常工作。
#将该集合传入一个单独的上下文控制器(control_dependencies),其参数为控制依赖,即在update_ops对象都计算完成后,才进行下面的op
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.train = tf.train.AdamOptimizer(LR).minimize(self.loss)

    def add_layer(self, x, out_size, ac=None):
        x = tf.layers.dense(x, out_size, kernel_initializer=self.w_init, bias_initializer=B_INIT)
        self.pre_activation.append(x)
# tf.layers.batch_normalization:参数:inputs:输入的tensor;axis:标准化的方向,0代表第一个维度(行),1代表第二个维度(列),momentum:移动平均值的动量参数
#----------------------------------- epsilon:防止除0操作产生时候设置的最小数(max(epsilon,div_num));center:bool量,true增加偏移量;
#----------------------------------- scale:bool量,true会将输出值乘与变量gammar,一般作为隐藏层不采用,输出层一般会采用。
#----------------------------------- trainable:bool量,true,将标准化的变量加入graph的collection里;
#----------------------------------- training:bool量,true时,表示使用变动的标准化参数(训练时),False,表示使用静态的标准化参数(测试或预测时)。
        if self.is_bn:
            x = tf.layers.batch_normalization(x, momentum=0.4, training=tf_is_train)    # 批标准化
        out = x if ac is None else ac(x)
        return out
#定义两个神经网络,一个采用了批标准化,一个没有,并对比效果
nets = [NN(batch_normalization=False), NN(batch_normalization=True)]

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# 绘图-输入数据的分布
f, axs = plt.subplots(4, N_HIDDEN+1, figsize=(10, 5))
plt.ion()   # 交互绘图模式

def plot_histogram(l_in, l_in_bn, pre_ac, pre_ac_bn):
    for i, (ax_pa, ax_pa_bn, ax,  ax_bn) in enumerate(zip(axs[0, :], axs[1, :], axs[2, :], axs[3, :])):
        [a.clear() for a in [ax_pa, ax_pa_bn, ax, ax_bn]]
        if i == 0: p_range = (-7, 10); the_range = (-7, 10)
        else: p_range = (-4, 4); the_range = (-1, 1)
        ax_pa.set_title('L' + str(i))
        ax_pa.hist(pre_ac[i].ravel(), bins=10, range=p_range, color='#FF9359', alpha=0.5)
        ax_pa_bn.hist(pre_ac_bn[i].ravel(), bins=10, range=p_range, color='#74BCFF', alpha=0.5)
        ax.hist(l_in[i].ravel(), bins=10, range=the_range, color='#FF9359')
        ax_bn.hist(l_in_bn[i].ravel(), bins=10, range=the_range, color='#74BCFF')
        for a in [ax_pa, ax, ax_pa_bn, ax_bn]:
            a.set_yticks(()); a.set_xticks(())
        ax_pa_bn.set_xticks(p_range); ax_bn.set_xticks(the_range); axs[2, 0].set_ylabel('Act'); axs[3, 0].set_ylabel('BN Act')
    plt.pause(0.01)

losses = [[], []]   # record test loss
for epoch in range(EPOCH):
    print('Epoch: ', epoch)
    np.random.shuffle(train_data)
    step = 0
    in_epoch = True
    while in_epoch:
        b_s, b_f = (step*BATCH_SIZE) % len(train_data), ((step+1)*BATCH_SIZE) % len(train_data) # batch index
        step += 1
        if b_f < b_s:
            b_f = len(train_data)
            in_epoch = False
        b_x, b_y = train_data[b_s: b_f, 0:1], train_data[b_s: b_f, 1:2]         # batch training data
        sess.run([nets[0].train, nets[1].train], {tf_x: b_x, tf_y: b_y, tf_is_train: True})     # train

        if step == 1:
            l0, l1, l_in, l_in_bn, pa, pa_bn = sess.run(
                [nets[0].loss, nets[1].loss, nets[0].layer_input, nets[1].layer_input,
                 nets[0].pre_activation, nets[1].pre_activation],
                {tf_x: test_x, tf_y: test_y, tf_is_train: False})
            [loss.append(l) for loss, l in zip(losses, [l0, l1])]   # recode test loss
            plot_histogram(l_in, l_in_bn, pa, pa_bn)     # plot histogram

plt.ioff()

# plot test loss
plt.figure(2)
plt.plot(losses[0], c='#FF9359', lw=3, label='Original')
plt.plot(losses[1], c='#74BCFF', lw=3, label='Batch Normalization')
plt.ylabel('test loss'); plt.ylim((0, 2000)); plt.legend(loc='best')

# plot prediction line
pred, pred_bn = sess.run([nets[0].out, nets[1].out], {tf_x: test_x, tf_is_train: False})
plt.figure(3)
plt.plot(test_x, pred, c='#FF9359', lw=4, label='Original')
plt.plot(test_x, pred_bn, c='#74BCFF', lw=4, label='Batch Normalization')
plt.scatter(x[:200], y[:200], c='r', s=50, alpha=0.2, label='train')
plt.legend(loc='best'); plt.show()

输出结果:
Epoch: 1
Epoch: 2
Epoch: 3
Epoch: 4
Epoch: 5
Epoch: 6
Epoch: 7
Epoch: 8
Epoch: 9
Epoch: 10
Epoch: 11

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值