【Tensorflow 2.0 正式版教程】Batch normalization实现

上一篇文章中讲解了如何实现自定义层,现在我们来实现一个非常特殊且重要的网络层:Batch normalization(批标准化层)。

该层的特殊之处是其在训练和测试阶段有着不同的行为,具体来说,该层需要计算出输入batch的均值和方差,然后对该batch的数据进行归一化,即减去均值再除以方差。训练阶段显然可以直接计算,但在测试阶段,每个样本间是没有关系的,即不应该存在batch的概念,那么只能依靠训练时的数据的统计结果代替测试阶段的均值与方差。那么代码可以这样写

def call(self, inputs, training):
    if training:
        batch_mean, batch_variance = tf.nn.moments(inputs, list(range(len(inputs.shape) - 1)))
        self.moving_mean = self.moving_mean * self.decay + batch_mean * (1 - self.decay)
        self.moving_variance = self.moving_variance * self.decay + batch_variance * (1 - self.decay)
        mean, variance = batch_mean, batch_variance
    else:
        mean, variance = self.moving_mean, self.moving_variance

这在eager模式下调试时没有问题,但是在真正训练时会报错,原因在于静态图无法处理这样的分支情况,原理这里不细讲了,我们需要利用Layeradd_update()方法和variable.assign()进行实现。如下

def assign_moving_average(self, variable, value):
    """
    variable = variable * decay + value * (1 - decay)
    """
    delta = variable * self.decay + value * (1 - self.decay)
    return variable.assign(delta)

@tf.function
def call(self, inputs, training):
    if training:
        batch_mean, batch_variance = tf.nn.moments(inputs, list(range(len(inputs.shape) - 1)))
        mean_update = self.assign_moving_average(self.moving_mean, batch_mean)
        variance_update = self.assign_moving_average(self.moving_variance, batch_variance)
        self.add_update(mean_update)
        self.add_update(variance_update)
        mean, variance = batch_mean, batch_variance
    else:
        mean, variance = self.moving_mean, self.moving_variance
    output = tf.nn.batch_normalization(inputs,
                                       mean=mean,
                                       variance=variance,
                                       offset=self.beta,
                                       scale=self.gamma,
                                       variance_epsilon=self.epsilon)
    return output

在tensorflow源码的注释中,作者的意思是add_update()方法几乎是为了batch normalization而量身定做的,其他的标准化层如Layer normalizationInstance Normalization都不涉及batch的操作,从而实现起来非常简单。

完整的实现代码可以在我的github上找到
https://github.com/Apm5/tensorflow_2.0_tutorial/blob/master/CNN/BatchNormalization.py

我的实现中计算方法是正确的,但缺乏进一步的优化,计算速度不如官方实现。实际应用中还是建议直接使用tf.keras.layers.BatchNormalization

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值