Using TensorFlow batch_normalization Correctly

Copyright notice: please credit the source when reposting: https://blog.csdn.net/dongjbstrong/article/details/80447110
                    <div class="htmledit_views">
<p>Aside: in TensorFlow, <code>tf.contrib</code> holds experimental or unstable code that may be deprecated or removed in the future; <code>tf.layers</code> holds settled code that rarely sees major changes; <code>tf.nn</code> is the lowest-level API and essentially never changes, but building higher-level functionality means composing its ops yourself.</p>
<p>To use <code>tf.layers.batch_normalization()</code>, start with the API reference: <a href="https://tensorflow.google.cn/api_docs/python/tf/layers/batch_normalization" rel="nofollow" target="_blank">tf.layers.batch_normalization</a></p>
<p><strong>Training:</strong></p>
<p><strong>1. Set training=True.</strong></p>
<p><strong>2. Add the code below so the moving averages are updated and saved; they are needed at test time.</strong></p>
<p>From the API docs: "Note: when training, the <code>moving_mean</code> and <code>moving_variance</code> need to be updated. By default the update ops are placed in <a href="https://tensorflow.google.cn/api_docs/python/tf/GraphKeys#UPDATE_OPS" rel="nofollow" target="_blank"><code>tf.GraphKeys.UPDATE_OPS</code></a>, so they need to be added as a dependency to the <code>train_op</code>. Also, be sure to add any batch_normalization ops before getting the update_ops collection. Otherwise, update_ops will be empty, and training/inference will not work properly. For example:"</p>
<pre><code class="language-python">x_norm = tf.layers.batch_normalization(x, training=training)

# ...

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):  # ensure update_ops run before train_op
    train_op = optimizer.minimize(loss)
</code></pre>
<p><strong>3. When saving the model:</strong></p>
<pre><code class="language-python">var_list = tf.trainable_variables()
g_list = tf.global_variables()
bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
var_list += bn_moving_vars
saver = tf.train.Saver(var_list=var_list, max_to_keep=5)
</code></pre>
<p><strong>Include the moving averages in var_list this way, because tf.trainable_variables() does not contain them; otherwise things go wrong when the model is restored.</strong></p>
<p><strong>Inference:</strong></p>
<p><strong>Set training=False. (When the training batch_size is 1, training=False does not necessarily test well: the tiny batch can make the moving averages unstable, so evaluating with the moving averages performs poorly, and training=True may actually work better. <span style="color:#ff0000;">Treat it as a hyperparameter to experiment with.</span>)</strong></p>
<p><a href="https://stackoverflow.com/questions/46573345/how-to-correctly-use-the-tf-layers-batch-normalization-in-tensorflow" rel="nofollow" target="_blank">This Stack Overflow question</a> reports poor test results even with training=False; the poster tried the following at test time:</p>
<pre><code class="language-python">update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    logits = yourmodel.inference(inputs_)
</code></pre>
<p><strong><a href="https://stackoverflow.com/questions/50602885/if-the-batch-size-equals-1-in-tf-layers-batch-normalization-will-it-works-cor/50605338#50605338" rel="nofollow" target="_blank">This answer</a> explains why, when batch_size = 1, batch_norm effectively becomes instance_norm.</strong></p>
<p>Reference: https://www.cnblogs.com/hrlnw/p/7227447.html</p>
            </div>
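To make the training/inference distinction concrete, here is a minimal NumPy sketch of what the layer computes (not the TensorFlow implementation; the learned scale/offset gamma and beta are omitted, and `MOMENTUM`/`EPSILON` mirror the layer's documented defaults of 0.99 and 1e-3). The moving-average update inside the `training` branch is exactly the work that the ops in `tf.GraphKeys.UPDATE_OPS` perform, which is why they must run alongside `train_op`:

```python
import numpy as np

MOMENTUM = 0.99   # default momentum of tf.layers.batch_normalization
EPSILON = 1e-3    # default epsilon

def batch_norm(x, moving_mean, moving_var, training):
    """Normalize a (batch, features) array, returning the output and the
    (possibly updated) moving statistics."""
    if training:
        # training=True: normalize with this batch's statistics ...
        mean = x.mean(axis=0)
        var = x.var(axis=0)
        # ... and update the moving averages -- this is what the ops in
        # tf.GraphKeys.UPDATE_OPS do; skip them and these never change.
        moving_mean = MOMENTUM * moving_mean + (1 - MOMENTUM) * mean
        moving_var = MOMENTUM * moving_var + (1 - MOMENTUM) * var
    else:
        # training=False: normalize with the accumulated moving averages.
        mean, var = moving_mean, moving_var
    out = (x - mean) / np.sqrt(var + EPSILON)
    return out, moving_mean, moving_var

x = np.array([[1.0, 10.0], [3.0, 30.0]])
mm, mv = np.zeros(2), np.ones(2)
out, mm, mv = batch_norm(x, mm, mv, training=True)
# If the update ops never run, mm/mv stay at their initial values and
# inference (training=False) normalizes with the wrong statistics.
```

This also makes the saving rule above plausible: `moving_mean`/`moving_var` are state, not trainable parameters, so they must be added to the saver's var_list explicitly.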
            </div>
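The batch_size = 1 point from the last link can be checked numerically. A sketch assuming NHWC feature maps and ignoring the learned scale/offset: batch norm averages over the batch and spatial axes, instance norm over the spatial axes only, so with a single sample the two coincide.

```python
import numpy as np

def batch_norm_train(x, eps=1e-3):
    # Batch norm on NHWC features: statistics over batch + spatial axes.
    mean = x.mean(axis=(0, 1, 2), keepdims=True)
    var = x.var(axis=(0, 1, 2), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def instance_norm(x, eps=1e-3):
    # Instance norm: statistics over spatial axes, per sample and channel.
    mean = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

x = np.random.RandomState(0).randn(1, 4, 4, 3)  # batch_size = 1
# With a single sample the two normalizations produce identical output.
assert np.allclose(batch_norm_train(x), instance_norm(x))
```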