TensorFlow Variable Management and Variable Sharing

<div id="article_content" class="article_content clearfix csdn-tracking-statistics" data-pid="blog" data-mod="popu_307" data-dsm="post">
								<div class="article-copyright">
					Copyright notice: This is an original article by the blogger and may not be reproduced without the blogger's permission.					https://blog.csdn.net/Michael__Corleone/article/details/78906318				</div>
								            <div id="content_views" class="markdown_views">
							<!-- flowchart arrow icon, do not delete -->
							<svg xmlns="http://www.w3.org/2000/svg" style="display: none;"><path stroke-linecap="round" d="M5,0 0,2.5 5,5z" id="raphael-marker-block" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path></svg>
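							<p>Today I'm writing down how tf.Variable and tf.get_variable are used. I fell into a deep pit with them while implementing a GAN: the results were consistently poor and no error was raised, and it turned out that the shared weights had not been handled correctly. I finally got it working, to my great delight, and I hope this post helps you avoid the same trap.</p>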

<h1 id="1tfvariable的使用"><a name="t0"></a>1. Using tf.Variable</h1>

<p>Usage:</p>



<pre class="prettyprint" name="code"><code class="language-python hljs  has-numbering">import tensorflow as tf  # TensorFlow 1.x graph-mode API
weights = tf.Variable(tf.constant(0.1, shape=[5]), name="weights")  # shape [5] is just an example</code></pre>
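<p>tf.Variable always builds a new variable from the initial value it is given. Below is a minimal pure-Python toy model of that behavior, so it can be tried without installing TensorFlow. It is a sketch under stated assumptions, not TensorFlow's implementation: the classes ToyGraph and ToyVariable are hypothetical, and the model assumes, as in TF 1.x graph mode, that every tf.Variable call creates a brand-new variable and that a name already in use gets a numeric suffix such as weights_1.</p>

<pre class="prettyprint" name="code"><code class="language-python hljs  has-numbering">class ToyVariable:
    """Hypothetical stand-in for a graph variable: a name plus a value."""

    def __init__(self, name, value):
        self.name = name
        self.value = value


class ToyGraph:
    """Hypothetical model of TF 1.x graph naming: a requested name that is
    already taken gets a suffix _1, _2, ... (weights, weights_1, ...)."""

    def __init__(self):
        self._counts = {}

    def unique_name(self, name):
        count = self._counts.get(name, 0)
        self._counts[name] = count + 1
        return name if count == 0 else f"{name}_{count}"

    def variable(self, initial_value, name):
        # Like tf.Variable: always builds a new object, never looks one up.
        return ToyVariable(self.unique_name(name) + ":0", list(initial_value))


graph = ToyGraph()
w_a = graph.variable([0.1] * 5, name="weights")
w_b = graph.variable([0.1] * 5, name="weights")
print(w_a.name, w_b.name)  # weights:0 weights_1:0
print(w_a is w_b)          # False: two independent variables</code></pre>

<p>Because nothing is ever looked up, two calls with the same name give two independent variables, which is why tf.Variable alone cannot share weights.</p>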

<h1 id="2tfgetvariable的使用"><a name="t1"></a>2. Using tf.get_variable</h1>

<p>Usage:</p>



<pre class="prettyprint" name="code"><code class="language-python hljs  has-numbering">import tensorflow as tf  # TensorFlow 1.x graph-mode API
shape = [5]  # example shape
weights = tf.get_variable("weights", shape,
                          initializer=tf.truncated_normal_initializer(stddev=0.1))</code></pre>
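<p>Unlike tf.Variable, tf.get_variable takes a shape and an initializer rather than a ready-made value, and it identifies a variable by its full name. The hypothetical pure-Python sketch below (ToyVariableStore and make_constant_init are illustrative names, not TensorFlow APIs) models both points: the initializer is only called when the variable is first created, and asking for an existing name again without reuse enabled is an error, which TF 1.x reports as a ValueError.</p>

<pre class="prettyprint" name="code"><code class="language-python hljs  has-numbering">class ToyVariableStore:
    """Hypothetical name-keyed store, loosely modeled on how tf.get_variable
    keeps at most one variable per full name."""

    def __init__(self):
        self._vars = {}
        self.init_calls = 0  # counts how often an initializer actually ran

    def get_variable(self, name, shape, initializer):
        full_name = name + ":0"
        if full_name in self._vars:
            # Without reuse, a duplicate name is an error, not a new variable.
            raise ValueError(f"Variable {name} already exists")
        # The initializer only runs when the variable is created.
        self.init_calls += 1
        self._vars[full_name] = initializer(shape)
        return full_name, self._vars[full_name]


def make_constant_init(value):
    # Hypothetical helper: returns a function that maps a shape to a flat list.
    def init(shape):
        size = 1
        for dim in shape:
            size *= dim
        return [value] * size
    return init


store = ToyVariableStore()
name, value = store.get_variable("weights", [5], make_constant_init(0.1))
print(name, value)  # weights:0 [0.1, 0.1, 0.1, 0.1, 0.1]
try:
    store.get_variable("weights", [5], make_constant_init(0.1))
except ValueError as err:
    print("error:", err)  # error: Variable weights already exists</code></pre>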

<h1 id="3两者区别"><a name="t2"></a>3. Differences between the two</h1>



<h2 id="31">3.1. tf.Variable always creates new variables</h2>

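<p>tf.Variable creates a new variable every time it is called, so running the same code again produces new variables with new, automatically generated names:</p>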

<pre class="prettyprint" name="code"><code class="language-python hljs  has-numbering"># variabletest.py
import tensorflow as tf  # TensorFlow 1.x graph-mode API


def test():
    # Create a variable in the D_layer1 scope (reuse defaults to False;
    # tf.Variable ignores the reuse flag anyway)
    with tf.variable_scope('D_layer1'):
        weights1 = tf.Variable(tf.constant(0.1, shape=[5]), name="weights")
        name1 = weights1.name
    # Create a variable in the D_layer2 scope
    with tf.variable_scope('D_layer2'):
        weights2 = tf.Variable(tf.constant(0.1, shape=[5]), name="weights")
        name2 = weights2.name
    return name1, name2</code></pre>

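<p>tf.variable_scope('D_layer1') opens a scope named D_layer1, and every variable created inside it is named with D_layer1/ as a prefix. Calling the function above twice gives the result below. Because the scopes are entered a second time, their names are made unique (D_layer1_1, D_layer2_1), so the second call creates two new variables instead of reusing the first ones:</p>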



<pre class="prettyprint" name="code"><code class="language-python hljs  has-numbering">import variabletest  # the module defined above

# Call the same function twice in one graph
name11, name12 = variabletest.test()
name21, name22 = variabletest.test()
print(name11)  # D_layer1/weights:0
print(name12)  # D_layer2/weights:0
print(name21)  # D_layer1_1/weights:0
print(name22)  # D_layer2_1/weights:0</code></pre>
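<p>The _1 suffix in the second pair of names comes from the scope being entered a second time: its name is made unique (D_layer1 becomes D_layer1_1), and tf.Variable takes its prefix from that uniquified name. The following hypothetical pure-Python model of that rule (ToyNameScopes and build_layers are illustrative names, not TensorFlow code) reproduces the four names printed above:</p>

<pre class="prettyprint" name="code"><code class="language-python hljs  has-numbering">import contextlib


class ToyNameScopes:
    """Hypothetical model: re-entering a scope name makes it unique
    (D_layer1, D_layer1_1, ...), and tf.Variable-style names take their
    prefix from the uniquified scope path."""

    def __init__(self):
        self._counts = {}  # how often each full scope path has been opened
        self._stack = []   # uniquified names of the currently open scopes

    @contextlib.contextmanager
    def scope(self, name):
        path = "/".join(self._stack + [name])
        count = self._counts.get(path, 0)
        self._counts[path] = count + 1
        self._stack.append(name if count == 0 else f"{name}_{count}")
        try:
            yield
        finally:
            self._stack.pop()

    def variable_name(self, name):
        # Full name = open scope path + variable name + output index ":0".
        return "/".join(self._stack + [name]) + ":0"


scopes = ToyNameScopes()


def build_layers():
    # Same structure as variabletest.test() in section 3.1.
    with scopes.scope("D_layer1"):
        name1 = scopes.variable_name("weights")
    with scopes.scope("D_layer2"):
        name2 = scopes.variable_name("weights")
    return name1, name2


print(build_layers())  # ('D_layer1/weights:0', 'D_layer2/weights:0')
print(build_layers())  # ('D_layer1_1/weights:0', 'D_layer2_1/weights:0')</code></pre>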

<h2 id="32实现共享变量"><a name="t4"></a>3.2. Implementing shared variables</h2>

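<p>tf.get_variable, by contrast, identifies a variable by its full name. When the enclosing scope's reuse argument is True, a repeated call returns the existing variable instead of creating a new one under a new name:</p>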

<pre class="prettyprint" name="code"><code class="language-python hljs  has-numbering"># variabletest.py (revised)
import tensorflow as tf  # TensorFlow 1.x graph-mode API


def test(reuse):
    # reuse=False: create D_layer1/weights; reuse=True: fetch the existing one
    with tf.variable_scope('D_layer1', reuse=reuse):
        weights1 = tf.get_variable("weights", [5],
                                   initializer=tf.truncated_normal_initializer(stddev=0.1))
        name1 = weights1.name
    # Same rule for the D_layer2 scope
    with tf.variable_scope('D_layer2', reuse=reuse):
        weights2 = tf.get_variable("weights", [5],
                                   initializer=tf.truncated_normal_initializer(stddev=0.1))
        name2 = weights2.name
    return name1, name2</code></pre>

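<p>In networks such as a GAN, weights must be shared, which means calling the same forward-propagation function several times (for example, running the discriminator on both real and generated samples). tf.Variable cannot achieve this, because every call creates new variables; the only workaround is to create the tf.Variable objects in the main function and pass them in, which hurts encapsulation. The better approach is to use tf.get_variable together with tf.variable_scope. Calling the function first with reuse=False and then with reuse=True gives:</p>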



<pre class="prettyprint" name="code"><code class="language-python hljs  has-numbering">import variabletest  # the revised module above

name11, name12 = variabletest.test(False)  # first call: create the variables
name21, name22 = variabletest.test(True)   # second call: reuse them
print(name11)  # D_layer1/weights:0
print(name12)  # D_layer2/weights:0
print(name21)  # D_layer1/weights:0
print(name22)  # D_layer2/weights:0</code></pre>
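<p>The reuse rules can also be written down as a small pure-Python model. This is a hypothetical sketch, not TensorFlow's code: ToyVarStore, its scope method, discriminator, and the string "auto" (standing in for tf.AUTO_REUSE, which later TF 1.x releases provide) are illustrative names. It assumes, modeled on TF 1.x, that reuse=False creates variables and refuses duplicates, reuse=True only returns existing variables, and a falsy reuse inside a nested scope inherits the parent's flag. A discriminator-style function called on a "real" and a "fake" input then sees the very same weights object:</p>

<pre class="prettyprint" name="code"><code class="language-python hljs  has-numbering">import contextlib


class ToyVarStore:
    """Hypothetical model of tf.get_variable inside tf.variable_scope(reuse=...).
    reuse may be False, True, or "auto" (standing in for tf.AUTO_REUSE)."""

    def __init__(self):
        self._vars = {}
        self._scope = []
        self._reuse = False

    @contextlib.contextmanager
    def scope(self, name, reuse=False):
        saved = self._reuse
        self._scope.append(name)
        # Modeled on TF 1.x: a falsy reuse inherits the parent scope's flag.
        self._reuse = reuse if reuse else saved
        try:
            yield
        finally:
            self._scope.pop()
            self._reuse = saved

    def get_variable(self, name, size):
        # size: length of a 1-D variable (a simplification of TF's shape).
        full = "/".join(self._scope + [name]) + ":0"
        exists = full in self._vars
        if self._reuse is True and not exists:
            raise ValueError(f"Variable {full} does not exist")
        if self._reuse is False and exists:
            raise ValueError(f"Variable {full} already exists")
        if not exists:  # reuse=False, or "auto" on first use: create it
            self._vars[full] = {"name": full, "value": [0.1] * size}
        return self._vars[full]


store = ToyVarStore()


def discriminator(x, reuse):
    # Hypothetical one-layer "discriminator": a weighted sum of the inputs.
    with store.scope("D_layer1", reuse=reuse):
        w = store.get_variable("weights", len(x))
    return w, sum(wi * xi for wi, xi in zip(w["value"], x))


w_real, d_real = discriminator([1.0, 2.0, 3.0], reuse=False)  # creates the weights
w_fake, d_fake = discriminator([0.5, 0.5, 0.5], reuse=True)   # fetches them
print(w_real["name"], w_fake["name"])  # D_layer1/weights:0 D_layer1/weights:0
print(w_real is w_fake)                # True: both inputs see the same weights

try:
    discriminator([0.5, 0.5, 0.5], reuse=False)
except ValueError as err:
    print(err)  # Variable D_layer1/weights:0 already exists

w_auto, _ = discriminator([1.0, 1.0, 1.0], reuse="auto")  # create-or-fetch
print(w_auto is w_real)                # True</code></pre>

<p>Because both calls return the same object, any update to the weights made through one call is visible to the other, which is the property a GAN discriminator needs when it scores real and generated samples.</p>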

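<p>When the reuse argument of tf.variable_scope is False (or left unset), tf.get_variable creates a new variable, and raises a ValueError if a variable with that name already exists. When reuse is True, it looks up and returns the existing variable, and raises a ValueError if none exists. tf.Variable ignores the reuse flag entirely. This is how the code above shares variables across calls.</p>            </div>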
						<link href="https://csdnimg.cn/release/phoenix/mdeditor/markdown_views-a47e74522c.css" rel="stylesheet">
                </div>

 
