TensorFlow Variables


1. Shared variables

Example 1

import tensorflow as tf

with tf.device('/cpu:0'):
    with tf.variable_scope('yiqingyang') as sco:
        w_1 = tf.get_variable(name="w_1", initializer=1.0)
        print(w_1)
        # Mark the current scope as reusing, so the next get_variable
        # with the same name returns the existing variable.
        tf.get_variable_scope().reuse_variables()

        with tf.device('/gpu:0'):
            w_2 = tf.get_variable(name="w_1", initializer=1.0)  # reuses w_1
            print(w_2)
            w_3 = w_2.assign_add(1)
            print(w_3)
            w_4 = w_3 * 3
            print(w_4)

init_op = tf.global_variables_initializer()
with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
    sess.run(init_op)
    print(sess.run([w_4, w_2]))

Output:

Tensor("yiqingyang/w_1/read:0", shape=(), dtype=float32, device=/device:CPU:0)
Tensor("yiqingyang/w_1/read:0", shape=(), dtype=float32, device=/device:CPU:0)
Tensor("yiqingyang/AssignAdd:0", shape=(), dtype=float32_ref, device=/device:CPU:0)
Tensor("yiqingyang/mul:0", shape=(), dtype=float32, device=/device:GPU:0)

[6.0, 2.0]

From this output, w_1 is placed on cpu:0. w_2, although defined inside the gpu:0 block, also ends up on cpu:0: tf.get_variable found an existing variable with the same name and simply reused it, and a reused variable keeps its original placement. The w_4 op, by contrast, does run on GPU:0. A second observation is that variable sharing is independent of device placement, i.e., wrapping the code in with tf.device does not change a variable's name.
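As an aside, TF 1.4+ also offers tf.AUTO_REUSE, which removes the need for the explicit reuse_variables() call; a minimal sketch (the scope name 'demo' is made up):

import tensorflow as tf

with tf.variable_scope('demo', reuse=tf.AUTO_REUSE):
    a = tf.get_variable(name="w_1", initializer=1.0)  # created on the first call
    b = tf.get_variable(name="w_1", initializer=1.0)  # reused on later calls
print(a is b)  # True: both handles refer to the same variable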

Example 2

import tensorflow as tf

with tf.variable_scope('yq') as sco:
    w_1 = tf.get_variable(name="w_1", initializer=1.0)
    print(w_1)
    tf.get_variable_scope().reuse_variables()

# Opening a new scope by the *string* 'yq' does not inherit the reuse
# flag set above, so this get_variable tries to create 'yq/w_1' again.
with tf.variable_scope('yq') as sc1:
    w_2 = tf.get_variable(name="w_1", initializer=1.0)  # raises ValueError
    print(w_2)

init_op = tf.global_variables_initializer()
with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
    sess.run(init_op)
    print(sess.run([w_1, w_2]))

Running this raises an error saying that the given variable name is already defined. The variant below, however, works, and the variable is shared:

import tensorflow as tf

with tf.variable_scope('yq') as sco:
    w_1 = tf.get_variable(name="w_1", initializer=1.0)
    print(w_1)
    tf.get_variable_scope().reuse_variables()

# Re-entering the captured scope object inherits its reuse flag.
# Equivalently: with tf.variable_scope('yq', reuse=True) as sc1:
with tf.variable_scope(sco) as sc1:
    w_2 = tf.get_variable(name="w_1", initializer=1.0)  # reuses 'yq/w_1'
    print(w_2)

init_op = tf.global_variables_initializer()
with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
    sess.run(init_op)
    print(sess.run([w_1, w_2]))
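A quick sanity check, added here for illustration, that the two handles really are one variable: get_variable with reuse returns the object already stored in the variable store, so the handles compare identical.

print(w_1 is w_2)          # True
print(w_1.name, w_2.name)  # both print 'yq/w_1:0'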

2. Distributed shared variables

The contents of example.py:

# -*- coding: utf-8 -*-
import tensorflow as tf
from time import sleep

# Configuration of the cluster
ps_hosts = ["localhost:2229"]
worker_hosts = ["localhost:2228"]
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

tf.app.flags.DEFINE_string("job_name", "worker", "One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")

FLAGS = tf.app.flags.FLAGS

def main(_):
    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)

    if FLAGS.job_name == "ps":
        with tf.device("/job:ps/task:0"):
            with tf.variable_scope('param_yi') as scopes:
                w = tf.get_variable(name="w_1", initializer=1.0)
        uninited = tf.report_uninitialized_variables()
        sess_ps = tf.Session(server.target)
        print("Server 1: waiting for connection...")
        while len(sess_ps.run(uninited)) > 0:
            print("Server 1: waiting for initialization...")
            sleep(1.0)
        print("Server 1: variables initialized!")
        sleep(5)
        print('w:')
        print(sess_ps.run(w))
        server.join()

    elif FLAGS.job_name == "worker":

        # Choose where the variable lives; here it is pinned to the first ps
        # task, matching the device used by the ps process above.
        with tf.device("/job:ps/task:0"):
            with tf.variable_scope('param_yi') as scopes:
                w = tf.get_variable(name="w_1", initializer=1.0)
        init_op = tf.global_variables_initializer()

        # Choose the master used to create the session
        with tf.Session(server.target) as sess:
            print('worker init')
            sess.run(init_op)
            print('w:')
            print(sess.run(w.assign_add(1.0)))
            server.join()

if __name__ == "__main__":
    tf.app.run()

On a single machine, first run this in one terminal:

python example.py --job_name=ps --task_index=0

Then, in a second terminal, run:

python example.py --job_name=worker --task_index=0

Both processes finally print 2, so the variable w is shared between the ps and the worker process.
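For real clusters the usual idiom is tf.train.replica_device_setter, which places variables on ps tasks automatically (round-robin across tasks) while leaving ops on the worker. A sketch against the same cluster spec, assuming TF 1.x:

with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % FLAGS.task_index,
        cluster=cluster)):
    # Variables created here land on /job:ps/...; ops stay on the worker.
    w = tf.get_variable(name="w_1", initializer=1.0)
    step = w.assign_add(1.0)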

But if the worker instead pins w to /job:worker/task:0, as in the following variant, the ps task waits for variable initialization forever:

# -*- coding: utf-8 -*-
import tensorflow as tf
from time import sleep

# Configuration of the cluster
ps_hosts = ["localhost:2229"]
worker_hosts = ["localhost:2228"]
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

tf.app.flags.DEFINE_string("job_name", "worker", "One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")

FLAGS = tf.app.flags.FLAGS

def main(_):
    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)

    if FLAGS.job_name == "ps":
        with tf.device("/job:ps/task:0"):
            with tf.variable_scope('param_yi') as scopes:
                w = tf.get_variable(name="w_1", initializer=1.0)
        uninited = tf.report_uninitialized_variables()
        sess_ps = tf.Session(server.target)
        print("Server 1: waiting for connection...")
        while len(sess_ps.run(uninited)) > 0:
            print("Server 1: waiting for initialization...")
            sleep(1.0)
        print("Server 1: variables initialized!")
        sleep(5)
        print('w:')
        print(sess_ps.run(w))
        server.join()

    elif FLAGS.job_name == "worker":

        # Choose where the variable lives and where ops run; this time
        # everything is pinned to the first *worker* task.
        with tf.device("/job:worker/task:0"):
            with tf.variable_scope('param_yi') as scopes:
                w = tf.get_variable(name="w_1", initializer=1.0)
        init_op = tf.global_variables_initializer()

        # Choose the master used to create the session
        with tf.Session(server.target) as sess:
            print('worker init')
            sess.run(init_op)
            print('w:')
            print(sess.run(w.assign_add(1.0)))
            server.join()

if __name__ == "__main__":
    tf.app.run()
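The reason is that the two graphs now pin w to different devices, so the same name refers to two distinct resources. You can make this visible by printing the device each graph assigned (a small diagnostic, added for illustration):

print(w.device)  # ps process: /job:ps/task:0; worker process: /job:worker/task:0

The worker initializes only its own copy on the worker task, while the ps keeps polling tf.report_uninitialized_variables() for the copy on the ps task, which nobody ever initializes, hence the endless wait.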

Sharing also works without pinning any device:

# -*- coding: utf-8 -*-
import tensorflow as tf
from time import sleep

# Configuration of the cluster
ps_hosts = ["localhost:2229"]
worker_hosts = ["localhost:2228"]
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

tf.app.flags.DEFINE_string("job_name", "worker", "One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")

FLAGS = tf.app.flags.FLAGS

def main(_):
    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)

    if FLAGS.job_name == "ps":
        # No tf.device or tf.variable_scope this time: placement is left
        # entirely to TensorFlow.
        w = tf.get_variable(name="w_1", initializer=1.0)
        print(w)
        uninited = tf.report_uninitialized_variables()
        sess_ps = tf.Session(server.target)
        print("Server 1: waiting for connection...")
        while len(sess_ps.run(uninited)) > 0:
            print("Server 1: waiting for initialization...")
            sleep(1.0)
        print("Server 1: variables initialized!")
        sleep(5)
        print('w:')
        print(sess_ps.run(w))
        server.join()

    elif FLAGS.job_name == "worker":

        # Again no device pinning; the unconstrained variable still resolves
        # to the same resource as in the ps process.
        w = tf.get_variable(name="w_1", initializer=1.0)
        print(w)
        init_op = tf.global_variables_initializer()

        # Choose the master used to create the session
        with tf.Session(server.target) as sess:
            print('worker init')
            sess.run(init_op)
            print('w:')
            print(sess.run(w.assign_add(1.0)))
            server.join()

if __name__ == "__main__":
    tf.app.run()

Both processes print 2 here as well.
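Where does the unconstrained w actually land? Since no device was set in the graph, w.device is empty; the placement log is what shows the real assignment. A sketch reusing the log_device_placement trick from section 1 (presumably both processes report the same /job:.../task:... device, which is why the name resolves to one shared resource):

sess = tf.Session(server.target,
                  config=tf.ConfigProto(log_device_placement=True))
sess.run(w)  # the log lists the device that hosts w's ops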
