BatchNorm在TensorFlow中的应用

x = tf.nn.batch_normalization(x, mean, variance, beta, gama, BN_EPSILON)

x为输入数据,mean为批量数据x的均值,variance为批量数据x的方差(注意均值,方差为每一个维度求均值,方差),beta和gama分别为可学习的平移参数和缩放参数,BN_EPSILON防止方差为0(通常设为0.001)。

完整的bn函数如下

def bn(x, use_bn, is_training):
  x_shape = x.get_shape()
  params_shape = x_shape[-1:]
  if not use_bn:
    bias = _get_variable('bias', params_shape, initializer=tf.zeros_initializer)
    retrun x + bias
  axis = list(range(len(x_shape) - 1))
  beta = _get_variable('beta', params_shape, initializer=tf.zeros_initializer)
  gamma = _get_variable('gamma', params_shape, initializer=tf.ones_initializer)
  moving_mean = _get_variable('moving_mean', params_shape, initializer=tf.zeros_initializer, trainable=False)
  moving_variance = _get_variable(moving_variance, params_shape, initializer=tf.ones_initializer, trainable=False)

  # these ops will only be performed when training.
  mean, variance = tf.nn.moments(x, axis)
  update_moving_mean = moving_averages.assign_moving_average(moving_mean, mean, BN_DECAY)
  update_moving_variance = moving_averages.assign_moving_average(moving_variance, variance, BN_DECAY)
  tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_mean)
  tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_variance)
  mean, variance = control_flow_ops.cond(is_training, lambda:(mean, variance), lambda:(moving_mean, moving_variance))
  x = tf.nn.batch_normalization(x, mean, variance, beta, gamma, BN_EPSILON)
  return x

1.下面简单介绍下里面所用的函数:

tf.nn.moments(x, axes, shift=None, name=None, keep_dims=False)计算x的均值和方差(当x的维度为[batch, height, width, depth]时,对于global normalization,axes = [0, 1, 2], 此时mean和variance的维度为[depth],对于simple batch normalization, axes = [0], 此时mean和variance的维度为[height, width, depth])

moving_averages.assign_moving_average(variable, value, decay, zero_debias=True, name=None)计算变量的滑动平均值,更新后变量的值为variable * decay + value * (1 - decay), 在本例中decay = BN_DECAY = 0.9997

tf.add_to_collection(name, value)把变量放入一个集合,集合的关键字为name

control_flow_ops.cond(pred, true_fn=None, false_fn=None, strict=False, name=None, fn1=None, fn2=None)返回true_fn如果pred为True否则返回false_fn。本例中用以控制训练和测试使用的均值和方差。注意pred不能是python bool

2.接下来简单介绍怎么更新bn的参数:

batchnorm_updates = tf.get_collection(UPDATE_OPS_COLLECTION)
batchnorm_updates_op = tf.group(*batchnorm_updates)
train_op = tf.group(apply_gradient_op, batchnorm_updates_op)

tf.get_collection(key, scope=None):从关键字为key的集合中取出全部的变量,如果scope不为None,则取出该集合中包含scope变量名的变量。

tf.group(*inputs, **kwargs):将一些operation或者变量group起来。




  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值