<article class="baidu_pl">
<div id="article_content" class="article_content clearfix csdn-tracking-statistics" data-pid="blog" data-mod="popu_307" data-dsm="post">
<div id="content_views" class="markdown_views prism-atom-one-light">
<!-- flowchart arrow icon, do not delete -->
<svg xmlns="http://www.w3.org/2000/svg" style="display: none;"><path stroke-linecap="round" d="M5,0 0,2.5 5,5z" id="raphael-marker-block" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path></svg>
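<p>I recently used batch_norm in TensorFlow without understanding how it works internally, misused it, and got results that did not match my expectations. I then looked into the issue, and this post summarizes what I found.</p>
<h2><a name="t0"></a><a id="_2" target="_blank"></a>I. The incorrect usage and its result</h2>
<p>At first I only knew that TensorFlow provides <code>tf.layers.batch_normalization</code>, so I called it directly in my model function. It takes a <code>training</code> argument, which should be <code>True</code> during training and <code>False</code> during testing. After training, however, something odd happened: with <code>training</code> set to <code>True</code> the test accuracy was normal, but with <code>training</code> set to <code>False</code> the test accuracy was very low. The incorrect usage boils down to the following snippet:</p>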
<pre class="prettyprint"><code class="has-numbering">import tensorflow as tf

is_training = tf.placeholder(dtype=tf.bool)
inputs = tf.ones([1, 2, 2, 3])
output = tf.layers.batch_normalization(inputs, training=is_training)
loss = ...
# Incorrect: train_op does not depend on the moving-average update ops.
train_op = optimizer.minimize(loss)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op, feed_dict={is_training: True})
</code></pre>
<h2><a name="t1"></a><a id="batch_normalization_17" target="_blank"></a>II. batch_normalization</h2>
<p>First, a brief introduction to batch_normalization. The figures below show a schematic of this normalization method and the algorithm:<br>
<img src="https://img-blog.csdnimg.cn/20181215150445911.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2h1aXRhaWxhbmd5eg==,size_16,color_FFFFFF,t_70" alt="Schematic of batch normalization"><img src="https://img-blog.csdnimg.cn/20181215150247465.png" alt="The batch normalization algorithm"><br>
In short, for one batch of inputs of shape [batch_num, height, width, channel], the mean and variance are computed channel by channel over all the data in the batch, the input is normalized with that mean and variance, and the result is then passed through a linear transform (a scale and an offset) to give the final batch_norm output. In outline:</p>
<pre class="prettyprint"><code class="has-numbering">import numpy as np

def batch_norm(inputs, scale, offset, epsilon=1e-3):
    """Batch-normalize an array of shape [batch_num, height, width, channel]."""
    outputs = np.empty(inputs.shape, dtype=np.float64)
    for i in range(inputs.shape[-1]):
        x = inputs[:, :, :, i]
        mean = np.mean(x)
        variance = np.var(x)
        # epsilon guards against division by zero when the variance is 0
        x = (x - mean) / np.sqrt(variance + epsilon)
        outputs[:, :, :, i] = scale[i] * x + offset[i]
    return outputs
</code></pre>
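<p>In practice, the implementation tracks a running average of the mean and variance seen during training, called moving_mean and moving_variance, and uses these stored values at test time instead of the statistics of the current batch. Selecting between the two is exactly what the <code>training</code> argument does. Implementations also usually update moving_mean and moving_variance gradually with a decay coefficient: <code>moving_mean = moving_mean * decay + new_batch_mean * (1 - decay)</code>, and likewise for moving_variance.</p>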
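<p>The decay update above can be tried out in a few lines of NumPy. This is only a minimal sketch: the function names, the synthetic data (channels drawn around a mean of 5) and the number of steps are assumptions made for illustration. It also shows an algebraically equivalent "subtract a fraction of the gap" form of the same update.</p>

```python
import numpy as np


def update_moving(moving, batch_value, decay=0.99):
    """Textbook form: keep `decay` of the old value, blend in the new one."""
    return moving * decay + batch_value * (1.0 - decay)


def update_moving_delta(moving, batch_value, decay=0.99):
    """Equivalent form: subtract a fraction of the gap to the batch value."""
    return moving - (moving - batch_value) * (1.0 - decay)


rng = np.random.default_rng(0)
moving_mean = np.zeros(3)        # one entry per channel, starts at 0
moving_mean_delta = np.zeros(3)

for _ in range(2000):
    # Hypothetical training batch: 32 samples, 3 channels, true mean 5.
    batch = rng.normal(loc=5.0, scale=2.0, size=(32, 3))
    batch_mean = batch.mean(axis=0)
    moving_mean = update_moving(moving_mean, batch_mean)
    moving_mean_delta = update_moving_delta(moving_mean_delta, batch_mean)

print(moving_mean)        # drifts from 0 towards the true mean of the data
print(moving_mean_delta)  # same values, computed with the delta form
```

<p>With a decay close to 1 the moving value changes slowly, so it needs many training steps before it is a good estimate of the statistics of the whole training set.</p>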
<h2><a name="t2"></a><a id="tensorflow_34" target="_blank"></a>III. Three implementations in TensorFlow</h2>
<p>TensorFlow currently offers three ways to apply batch_norm.</p>
<h3><a name="t3"></a><a id="1tfnnbatch_normalization_36" target="_blank"></a>1. tf.nn.batch_normalization (the lowest-level implementation)</h3>
<pre class="prettyprint"><code class="has-numbering">tf.nn.batch_normalization(
    x,
    mean,
    variance,
    offset,
    scale,
    variance_epsilon,
    name=None
)
</code></pre>
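<p>This function is the lowest-level implementation. The caller has to supply mean, variance, scale and offset and keep them updated, so in practice it must be wrapped before use. It is generally not recommended for direct use, but it is very helpful for understanding how batch_norm works.<br>
An example wrapper:</p>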
<pre class="prettyprint"><code class="has-numbering">import tensorflow as tf

def batch_norm(x, name_scope, training, epsilon=1e-3, decay=0.99):
    """Assume an nd tensor of shape [batch, N1, N2, ..., Nm, Channel]."""
    with tf.variable_scope(name_scope):
        size = x.get_shape().as_list()[-1]
        scale = tf.get_variable('scale', [size], initializer=tf.constant_initializer(0.1))
        offset = tf.get_variable('offset', [size])
        pop_mean = tf.get_variable('pop_mean', [size], initializer=tf.zeros_initializer(), trainable=False)
        pop_var = tf.get_variable('pop_var', [size], initializer=tf.ones_initializer(), trainable=False)
        # Reduce over every axis except the last (channel) axis.
        batch_mean, batch_var = tf.nn.moments(x, list(range(len(x.get_shape()) - 1)))

        def batch_statistics():
            # Create the assign ops inside the branch: ops built outside
            # tf.cond would run no matter which branch is taken.
            train_mean_op = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
            train_var_op = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
            with tf.control_dependencies([train_mean_op, train_var_op]):
                return tf.nn.batch_normalization(x, batch_mean, batch_var, offset, scale, epsilon)

        def population_statistics():
            return tf.nn.batch_normalization(x, pop_mean, pop_var, offset, scale, epsilon)

        return tf.cond(training, batch_statistics, population_statistics)

is_training = tf.placeholder(dtype=tf.bool)
inputs = tf.ones([1, 2, 2, 3])
output = batch_norm(inputs, name_scope='batch_norm_nn', training=is_training)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.save(sess, "batch_norm_nn/Model")
</code></pre>
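<p><code>batch_norm</code> first computes the per-channel mean and var of x and defines the ops that update pop_mean and pop_var. Then, depending on whether it is in the training or the testing phase, it passes either the current batch's mean and var or the mean and var saved during training, together with the newly defined variables scale and offset, to <code>tf.nn.batch_normalization</code>.</p>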
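<p>The same train/test switch can be written without a graph, which makes the control flow easier to follow. The NumPy sketch below is an illustrative analogue of the wrapper, not TensorFlow code: the function name <code>batch_norm_np</code>, the <code>state</code> dictionary and the random input are assumptions made for this example.</p>

```python
import numpy as np


def batch_norm_np(x, state, training, epsilon=1e-3, decay=0.99):
    """Normalize over every axis except the last (channel) axis."""
    axes = tuple(range(x.ndim - 1))
    if training:
        mean, var = x.mean(axis=axes), x.var(axis=axes)
        # Side effect that only happens in training mode.
        state["pop_mean"] = state["pop_mean"] * decay + mean * (1 - decay)
        state["pop_var"] = state["pop_var"] * decay + var * (1 - decay)
    else:
        # Test mode: reuse the statistics accumulated during training.
        mean, var = state["pop_mean"], state["pop_var"]
    x_hat = (x - mean) / np.sqrt(var + epsilon)
    return state["scale"] * x_hat + state["offset"]


state = {
    "scale": np.ones(3),
    "offset": np.zeros(3),
    "pop_mean": np.zeros(3),
    "pop_var": np.ones(3),
}
rng = np.random.default_rng(0)
x = rng.normal(loc=4.0, scale=3.0, size=(8, 2, 2, 3))

train_out = batch_norm_np(x, state, training=True)
pop_mean_after_train = state["pop_mean"].copy()
test_out = batch_norm_np(x, state, training=False)

print(train_out.mean(axis=(0, 1, 2)))  # per-channel mean of the training output
print(state["pop_mean"])               # moved a small step away from zero
```

<p>Only the training call touches <code>pop_mean</code> and <code>pop_var</code>; the test call just reads them, which is the behaviour the <code>tf.cond</code> in the wrapper is meant to express.</p>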
<h3><a name="t4"></a><a id="2tflayersbatch_normalization_87" target="_blank"></a>2. tf.layers.batch_normalization</h3>
<pre class="prettyprint"><code class="has-numbering">tf.layers.batch_normalization(
    inputs,
    axis=-1,
    momentum=0.99,
    epsilon=0.001,
    center=True,
    scale=True,
    beta_initializer=tf.zeros_initializer(),
    gamma_initializer=tf.ones_initializer(),
    moving_mean_initializer=tf.zeros_initializer(),
    moving_variance_initializer=tf.ones_initializer(),
    beta_regularizer=None,
    gamma_regularizer=None,
    beta_constraint=None,
    gamma_constraint=None,
    training=False,
    trainable=True,
    name=None,
    reuse=None,
    renorm=False,
    renorm_clipping=None,
    renorm_momentum=0.99,
    fused=None,
    virtual_batch_size=None,
    adjustment=None
)
</code></pre>
<p>This is the function I had been using. The official documentation says:</p>
<blockquote>
<p>Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op. Also, be sure to add any batch_normalization ops before getting the update_ops collection. Otherwise, update_ops will be empty, and training/inference will not work properly. For example:</p>
</blockquote>
<pre class="prettyprint"><code class="has-numbering">x_norm = tf.layers.batch_normalization(x, training=training)
# ...
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
</code></pre>
<p>As you can see, the main difference from my earlier incorrect usage is the following two lines:</p>
<pre class="prettyprint"><code class="has-numbering">update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
</code></pre>
<p>The wrapper around tf.nn.batch_normalization in the first method uses a similar technique. The details are explained later in this post.</p>
<h3><a name="t5"></a><a id="3tfcontriblayersbatch_normslim_138" target="_blank"></a>3. tf.contrib.layers.batch_norm (slim)</h3>
<pre class="prettyprint"><code class="has-numbering">tf.contrib.layers.batch_norm(
    inputs,
    decay=0.999,
    center=True,
    scale=False,
    epsilon=0.001,
    activation_fn=None,
    param_initializers=None,
    param_regularizers=None,
    updates_collections=tf.GraphKeys.UPDATE_OPS,
    is_training=True,
    reuse=None,
    variables_collections=None,
    outputs_collections=None,
    trainable=True,
    batch_weights=None,
    fused=None,
    data_format=DATA_FORMAT_NHWC,
    zero_debias_moving_mean=False,
    scope=None,
    renorm=False,
    renorm_clipping=None,
    renorm_decay=0.99,
    adjustment=None
)
</code></pre>
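<p>This function is used in much the same way as <code>tf.layers.batch_normalization</code>. The main difference lies in the defaults of the <code>scale</code> and <code>center</code> arguments, which control the <code>scale</code> and <code>offset</code> of the linear transform applied after the input has been normalized with the mean and variance, as described in the overview above. <code>center</code> (the offset) defaults to <code>True</code> in both, but <code>scale</code> defaults to <code>True</code> in <code>tf.layers.batch_normalization</code> and to <code>False</code> in <code>tf.contrib.layers.batch_norm</code>. In other words, by default <code>tf.contrib.layers.batch_norm</code> does not rescale the normalized input and only adds an offset. The two signatures above also differ in the default decay (<code>momentum=0.99</code> versus <code>decay=0.999</code>) and in the default of the training flag (<code>training=False</code> versus <code>is_training=True</code>).</p>
<h2><a name="t6"></a><a id="tfGraphKeysUPDATA_OPS_169" target="_blank"></a>IV. About tf.GraphKeys.UPDATE_OPS</h2>
<p>Two concepts remain to be explained: <code>tf.GraphKeys.UPDATE_OPS</code> and <code>tf.control_dependencies</code>.</p>
<h3><a name="t7"></a><a id="1tfcontrol_dependencies_171" target="_blank"></a>1. tf.control_dependencies</h3>
<p>Let us start with <code>tf.control_dependencies</code>. It guarantees that the operations created inside its scope run only after the operations passed to it as arguments have finished. Consider the following example.</p>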
<pre class="prettyprint"><code class="has-numbering">import tensorflow as tf

a_1 = tf.Variable(1)
b_1 = tf.Variable(2)
update_op_1 = tf.assign(a_1, 10)
add = tf.add(a_1, b_1)  # does not depend on update_op_1

a_2 = tf.Variable(1)
b_2 = tf.Variable(2)
update_op_2 = tf.assign(a_2, 10)
with tf.control_dependencies([update_op_2]):
    add_with_dependencies = tf.add(a_2, b_2)  # runs after update_op_2

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ans_1, ans_2 = sess.run([add, add_with_dependencies])
    print("Add: ", ans_1)
    print("Add_with_dependency: ", ans_2)

# Output:
# Add: 3
# Add_with_dependency: 12
</code></pre>
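<p>Comparing the two additions: in an ordinary graph, evaluating <code>add</code> never passes through the <code>tf.assign</code> update op, so <code>a</code> is still 1 when the addition runs. With <code>tf.control_dependencies</code>, the update op is forced to complete before the addition, so <code>a</code> is 10 when the addition runs. That is why the two additions give different results.</p>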
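<p>The idea that an op only runs when something being evaluated depends on it can be mimicked in plain Python. The toy executor below is a hypothetical model written for this explanation (the <code>Op</code> class and <code>run</code> function are not TensorFlow APIs); it only models control dependencies and records the order in which ops execute.</p>

```python
class Op:
    """A hypothetical graph node: a function plus the ops it must wait for."""

    def __init__(self, name, fn, control_inputs=()):
        self.name = name
        self.fn = fn
        self.control_inputs = list(control_inputs)


def run(op, trace):
    """Run every control input first, then the op itself; record the order."""
    for dep in op.control_inputs:
        run(dep, trace)
    trace.append(op.name)
    return op.fn()


vars_1 = {"a": 1, "b": 2}
vars_2 = {"a": 1, "b": 2}

# First group: the update exists but nothing depends on it.
update_1 = Op("update_1", lambda: vars_1.update(a=10))
add = Op("add", lambda: vars_1["a"] + vars_1["b"])

# Second group: the addition declares the update as a control input.
update_2 = Op("update_2", lambda: vars_2.update(a=10))
add_with_dependencies = Op(
    "add_with_dependencies",
    lambda: vars_2["a"] + vars_2["b"],
    control_inputs=[update_2],
)

trace = []
ans_1 = run(add, trace)
ans_2 = run(add_with_dependencies, trace)
print("Add:", ans_1)
print("Add_with_dependency:", ans_2)
print("Executed:", trace)
```

<p>In this toy model <code>update_1</code> never appears in the trace because nothing that was evaluated depends on it, which mirrors what happens to update ops that are left off the dependency path of <code>train_op</code>.</p>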
<h3><a name="t8"></a><a id="2tfGraphKeysUPDATE_OPS_198" target="_blank"></a>2. tf.GraphKeys.UPDATE_OPS</h3>
<p><code>tf.GraphKeys.UPDATE_OPS</code> is a collection built into the TensorFlow graph. It holds operations that must be completed before each training step, and it is meant to be used together with <code>tf.control_dependencies</code>.<br>
For batch_norm, these are the operations that update the moving mean and variance. The following example shows how <code>tf.layers.batch_normalization</code> does this.</p>
<pre class="prettyprint"><code class="has-numbering">import tensorflow as tf

is_training = tf.placeholder(dtype=tf.bool)
inputs = tf.ones([1, 2, 2, 3])
output = tf.layers.batch_normalization(inputs, training=is_training)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
print(update_ops)
# with tf.control_dependencies(update_ops):
#     train_op = optimizer.minimize(loss)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.save(sess, "batch_norm_layer/Model")

# Output:
# [&lt;tf.Tensor 'batch_normalization/AssignMovingAvg:0' shape=(3,) dtype=float32_ref&gt;, &lt;tf.Tensor 'batch_normalization/AssignMovingAvg_1:0' shape=(3,) dtype=float32_ref&gt;]
</code></pre>
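<p>The two entries printed are the operations in batch_normalization that update the moving mean and variance; they must be guaranteed to run before train_op.<br>
TensorFlow's internal implementation adds these two operations to the <code>tf.GraphKeys.UPDATE_OPS</code> collection automatically. In <code>tf.contrib.layers.batch_norm</code> this is visible as the <code>updates_collections</code> argument, whose default is <code>tf.GraphKeys.UPDATE_OPS</code>, whereas <code>tf.layers.batch_normalization</code> puts the two update operations into that collection directly.</p>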
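<p>The documentation's warning about fetching the collection only after the batch_normalization ops have been added comes down to ordering: the list you get back is a snapshot of what has been registered so far. The plain-Python sketch below imitates that behaviour. All names in it (<code>UPDATE_OPS</code>, <code>get_collection</code>, <code>batch_norm_layer</code>, <code>train_step</code>) are hypothetical stand-ins, not TensorFlow calls.</p>

```python
UPDATE_OPS = []  # hypothetical stand-in for the graph's UPDATE_OPS collection


def get_collection():
    """Return a snapshot (copy) of what has been registered so far."""
    return list(UPDATE_OPS)


def batch_norm_layer(state):
    """Pretend layer: registers its moving-average update as a side effect."""

    def update_moving_mean():
        state["moving_mean"] = 0.99 * state["moving_mean"] + 0.01 * state["batch_mean"]

    UPDATE_OPS.append(update_moving_mean)


def train_step(update_ops):
    """Run the given updates, then (conceptually) the optimizer step."""
    for op in update_ops:
        op()


state = {"moving_mean": 0.0, "batch_mean": 5.0}

too_early = get_collection()   # fetched before the layer exists: empty
batch_norm_layer(state)
in_time = get_collection()     # fetched after the layer was built

train_step(too_early)
mean_after_empty_step = state["moving_mean"]
train_step(in_time)
mean_after_real_step = state["moving_mean"]

print(len(too_early), len(in_time))
print(mean_after_empty_step, mean_after_real_step)
```

<p>A training step built from the early snapshot leaves the moving mean at its initial value, while one built from the later snapshot updates it.</p>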
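<h2><a name="t9"></a><a id="_226" target="_blank"></a>V. Why the original usage went wrong</h2>
<p>Finally, I thought about why my original usage caused the error. The code that actually implements batch_normalization in TensorFlow lives in <code>tensorflow\python\layers\normalization.py</code>; some key excerpts are shown below.</p>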
<pre class="prettyprint"><code class="has-numbering">if self.scale:
    self.gamma = self.add_variable(
        name='gamma',
        shape=param_shape,
        dtype=param_dtype,
        initializer=self.gamma_initializer,
        regularizer=self.gamma_regularizer,
        constraint=self.gamma_constraint,
        trainable=True)
else:
    self.gamma = None
if self.center:
    self.beta = self.add_variable(
        name='beta',
        shape=param_shape,
        dtype=param_dtype,
        initializer=self.beta_initializer,
        regularizer=self.beta_regularizer,
        constraint=self.beta_constraint,
        trainable=True)
else:
    self.beta = None
scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

self.moving_mean = self._add_tower_local_variable(
    name='moving_mean',
    shape=param_shape,
    dtype=param_dtype,
    initializer=self.moving_mean_initializer,
    trainable=False)
self.moving_variance = self._add_tower_local_variable(
    name='moving_variance',
    shape=param_shape,
    dtype=param_dtype,
    initializer=self.moving_variance_initializer,
    trainable=False)

def _assign_moving_average(self, variable, value, momentum):
    with ops.name_scope(None, 'AssignMovingAvg', [variable, value, momentum]) as scope:
        decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
        if decay.dtype != variable.dtype.base_dtype:
            decay = math_ops.cast(decay, variable.dtype.base_dtype)
        update_delta = (variable - value) * decay
        return state_ops.assign_sub(variable, update_delta, name=scope)

def _do_update(var, value):
    return self._assign_moving_average(var, value, self.momentum)

# Determine a boolean value for `training`: could be True, False, or None.
training_value = utils.constant_value(training)
if training_value is not False:
    mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
    moving_mean = self.moving_mean
    moving_variance = self.moving_variance
    mean = utils.smart_cond(training,
                            lambda: mean,
                            lambda: moving_mean)
    variance = utils.smart_cond(training,
                                lambda: variance,
                                lambda: moving_variance)
    # (the `if` branch that pairs with this `else` is not part of the excerpt)
    else:
        new_mean, new_variance = mean, variance
    mean_update = utils.smart_cond(
        training,
        lambda: _do_update(self.moving_mean, new_mean),
        lambda: self.moving_mean)
    variance_update = utils.smart_cond(
        training,
        lambda: _do_update(self.moving_variance, new_variance),
        lambda: self.moving_variance)
    if not context.executing_eagerly():
        self.add_update(mean_update, inputs=inputs)
        self.add_update(variance_update, inputs=inputs)

outputs = nn.batch_normalization(inputs,
                                 _broadcast(mean),
                                 _broadcast(variance),
                                 offset,
                                 scale,
                                 self.epsilon)
</code></pre>
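<p>As you can see, the internal logic is similar to the wrapper I showed in the section on <code>tf.nn.batch_normalization</code>.<br>
If <code>tf.control_dependencies</code> is not added, then during training (<code>training=True</code>) each step only computes the mean and var of the current batch and passes them to <code>tf.nn.batch_normalization</code> for normalization. Because mean_update and variance_update are not on the dependency path of those operations in the graph, nothing triggers them. In other words, mean_update and variance_update never run during training, so moving_mean and moving_variance keep their initial values (zeros and ones with the default initializers). In the test phase (<code>training=False</code>) these never-updated values are then used as the mean and variance for normalization, which produces the wrong results. With <code>tf.control_dependencies</code>, mean_update and variance_update are forced to run before every training step, so moving_mean and moving_variance are continually updated and using them at test time no longer causes errors.</p>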
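<p>The effect described above can be reproduced numerically without TensorFlow. The NumPy simulation below is a sketch under assumed conditions: the data distribution (mean 10, standard deviation 4), the batch size, the number of steps and the helper names are all invented for illustration. It trains twice, once skipping the moving-average updates and once running them, and then normalizes a test batch with the stored statistics.</p>

```python
import numpy as np


def normalize(x, mean, var, epsilon=1e-3):
    return (x - mean) / np.sqrt(var + epsilon)


def train(run_update_ops, steps=1000, momentum=0.99):
    """Return the moving statistics left behind after `steps` training batches."""
    rng = np.random.default_rng(0)
    moving_mean, moving_var = np.zeros(3), np.ones(3)  # initial values
    for _ in range(steps):
        batch = rng.normal(loc=10.0, scale=4.0, size=(64, 3))
        batch_mean, batch_var = batch.mean(axis=0), batch.var(axis=0)
        # The training-mode forward pass only needs the batch statistics.
        _ = normalize(batch, batch_mean, batch_var)
        if run_update_ops:
            moving_mean = moving_mean * momentum + batch_mean * (1 - momentum)
            moving_var = moving_var * momentum + batch_var * (1 - momentum)
    return moving_mean, moving_var


test_batch = np.random.default_rng(1).normal(loc=10.0, scale=4.0, size=(64, 3))
results = {}
for run_update_ops in (False, True):
    mean, var = train(run_update_ops)
    results[run_update_ops] = normalize(test_batch, mean, var)
    print("update ops run:", run_update_ops)
    print("  test output mean per channel:", results[run_update_ops].mean(axis=0))
    print("  test output std per channel: ", results[run_update_ops].std(axis=0))
```

<p>When the updates are skipped, the test batch is normalized with the initial mean 0 and variance 1, so it passes through almost unchanged and the following layers see inputs on a very different scale from the one they were trained on. When the updates run, the test output is roughly zero-mean with unit variance, matching what the network saw during training.</p>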
</div>
<link href="https://csdnimg.cn/release/phoenix/mdeditor/markdown_views-2b43bc2447.css" rel="stylesheet">
</div>
</article>