BN实现

这是按照我对原论文的理解写出的 BN 实现,应该是最贴近论文描述的版本;如有问题,欢迎指正。


def batch_norm(x, n_out, train, eps=1e-05, decay=0.99, affine=True, name=None):
    """Batch normalization over the N, H, W axes of a 4-D tensor (TF1 graph mode).

    Args:
        x: input tensor; moments are taken over axes [0, 1, 2], so x is
           assumed to be NHWC with C == n_out — TODO confirm with callers.
        n_out: number of channels (size of the statistics vectors).
        train: Python bool or scalar bool tensor. True -> use batch
            statistics and update the moving averages; False -> use the
            moving averages accumulated during training.
        eps: small constant added to the variance for numeric stability.
        decay: exponential-moving-average decay for the stored statistics.
        affine: if True, apply a learned shift (beta) and scale (gamma).
        name: optional variable-scope name (defaults to 'BatchNorm2d').

    Returns:
        The normalized tensor, same shape as x.
    """
    with tf.variable_scope(name, default_name='BatchNorm2d'):
        # Non-trainable running statistics, refreshed only in training mode
        # and consumed at inference time.
        moving_mean = tf.get_variable('mean', [n_out],
                                      initializer=tf.zeros_initializer,
                                      trainable=False)
        moving_variance = tf.get_variable('variance', [n_out],
                                          initializer=tf.ones_initializer,
                                          trainable=False)

        # Allow `train` to be a plain Python bool as well as a tensor.
        train = tf.convert_to_tensor(train)

        def mean_var_with_update():
            # Batch statistics over N, H, W; the control dependency forces
            # the moving-average updates to run whenever these are used.
            mean, variance = tf.nn.moments(x, [0, 1, 2], name='moments')
            from tensorflow.python.training.moving_averages import assign_moving_average
            with tf.control_dependencies(
                    [assign_moving_average(moving_mean, mean, decay),
                     assign_moving_average(moving_variance, variance, decay)]):
                return tf.identity(mean), tf.identity(variance)

        # train=True  -> batch mean/variance (with moving-average update);
        # train=False -> the moving averages updated during training.
        mean, variance = tf.cond(train, mean_var_with_update,
                                 lambda: (moving_mean, moving_variance))

        if affine:
            # Fixed: use tf.get_variable (not tf.Variable) so beta/gamma
            # live in this variable scope and are shared under scope reuse,
            # consistent with the moving statistics above. tf.Variable would
            # create new (suffixed) variables on every call.
            beta = tf.get_variable('beta', [n_out],
                                   initializer=tf.zeros_initializer,
                                   trainable=True)
            gamma = tf.get_variable('gamma', [n_out],
                                    initializer=tf.ones_initializer,
                                    trainable=True)
            x = tf.nn.batch_normalization(x, mean, variance, beta, gamma, eps)
        else:
            x = tf.nn.batch_normalization(x, mean, variance, None, None, eps)
        return x
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值