tensorflow中Batch Normalization的实现

tensorflow版本1.4

tensorflow目前还没实现完全封装好的Batch Normalization的实现,这里主要试着实现一下。
关于理论可参见《 解读Batch Normalization》

对于TensorFlow下的BN的实现,首先我们列举一下需要注意的事项:

  • (1)需要自动适应卷积层(batch_size*height*width*channel)和全连接层(batch_size*channel);
  • (2)需要能够分别处理Training和Testing的情况,Training时需要更新均值和方差,Testing时使用历史滑动平均得到的均值与方差,即需要提供is_training的标志位参数;
  • (3)最好提供滑动平均系数可调;
  • (4)BN的计算量较大,尽量提高存储与运算效率;
  • (5)需要注意alpha和beta参数可以被BP更新,而均值和方差通过计算得到;
  • (6)load模型时,历史均值、方差以及alpha和beta参数需要被正常加载;

最终的实现如下:

#coding=utf-8
# util.py 用于实现一些功能函数

import tensorflow as tf

# 实现Batch Normalization
def bn_layer(x,is_training,name='BatchNorm',moving_decay=0.9,eps=1e-5):
    # 获取输入维度并判断是否匹配卷积层(4)或者全连接层(2)
    shape = x.shape
    assert len(shape) in [2,4]

    param_shape = shape[-1]
    with tf.variable_scope(name):
        # 声明BN中唯一需要学习的两个参数,y=gamma*x+beta
        gamma = tf.get_variable('gamma',param_shape,initializer=tf.constant_initializer(1))
        beta  = tf.get_variable('beat', param_shape,initializer=tf.constant_initializer(0))

        # 计算当前整个batch的均值与方差
        axes = list(range(len(shape)-1))
        batch_mean, batch_var = tf.nn.moments(x,axes,name='moments')

        # 采用滑动平均更新均值与方差
        ema = tf.train.ExponentialMovingAverage(moving_decay)

        def mean_var_with_update():
            ema_apply_op = ema.apply([batch_mean,batch_var])
            with tf.control_dependencies([ema_apply_op]):
                return tf.identity(batch_mean), tf.identity(batch_var)

        # 训练时,更新均值与方差,测试时使用之前最后一次保存的均值与方差
        mean, var = tf.cond(tf.equal(is_training,True),mean_var_with_update,
                lambda:(ema.average(batch_mean),ema.average(batch_var)))

        # 最后执行batch normalization
        return tf.nn.batch_normalization(x,mean,var,beta,gamma,eps)

测试函数如下:

import util
import tensorflow as tf


# 注意bn_layer中滑动平均的操作导致该层只支持半精度、float32和float64类型变量
x = tf.constant([[1,2,3],[2,4,8],[3,9,27]],dtype=tf.float32)
y = util.bn_layer(x,True)

# 注意bn_layer中的一些操作必须被提前初始化
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print('x = ',x.eval())
    print('y = ',y.eval())

结果输出如下:(证明我们的初步计算是正确的)

x =  [[  1.   2.   3.]
 [  2.   4.   8.]
 [  3.   9.  27.]]
y =  [[-1.22473562 -1.01904869 -0.93499756]
 [ 0.         -0.33968294 -0.45137817]
 [ 1.2247355   1.35873151  1.38637543]]

下面介绍一下,实现过程中遇到的一些函数:

tf.nn.moments

# 用于在指定维度计算均值与方差
tf.nn.moments(
    x,
    axes,
    shift=None,
    name=None,
    keep_dims=False
)
  • x: 输入Tensor
  • axes: int型Array,用于指定在哪些维度计算均值与方差。如果x是1-D向量且axes=[0] 那么该函数就是计算整个向量的均值与方差
  • shift: 暂时无用

tf.train.ExponentialMovingAverage

# 类,用于计算滑动平均
tf.train.ExponentialMovingAverage

__init__(
    decay,
    num_updates=None,
    zero_debias=False,
    name='ExponentialMovingAverage'
)

具体的滑动公式如下,等价于一种指数衰减:

shadow_variable = decay * shadow_variable + (1 - decay) * variable

tf.control_dependencies

# tf.control_dependencies(control_inputs)返回一个控制依赖的上下文管理器,
# 使用with关键字可以让在这个上下文环境中的操作都在control_inputs之后执行
# 比如:
with tf.control_dependencies([a, b]):
  # 只有在a和b执行完后,c和d才会被执行
  c = ...
  d = ...

tf.cond

# 用于有条件的执行函数,当pred为True时,执行true_fn函数,否则执行false_fn函数
tf.cond(
    pred,
    true_fn=None,
    false_fn=None,
    strict=False,
    name=None,
    fn1=None,
    fn2=None
)

尤其需要注意的是,pred参数是tf.bool型变量,直接写“True”或者“False”是python型bool,会报错的。因此在我的BN实现中使用了tf.equal(is_training,True)的操作。

tf.nn.batch_normalization

# 用于最中执行batch normalization的函数
tf.nn.batch_normalization(
    x,
    mean,
    variance,
    offset,
    scale,
    variance_epsilon,
    name=None
)

计算公式为: y = scale*(x-mean)/var + offset

  • 10
    点赞
  • 40
    收藏
    觉得还不错? 一键收藏
  • 9
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值