tensorflow java 线程_在Keras和Tensorflow中复制模型以进行多线程设置

我想在Keras和TensorFlow中实现actor-critic的异步版本。我正在使用Keras作为构建网络层的前端(我正在使用tensorflow直接更新参数)。我有一个global_model和一个主要的tensorflow会话。但是在每个线程中,我创建一个local_model,它复制global_model中的参数。我的代码看起来是这样的在Keras和Tensorflow中复制模型以进行多线程设置

def main(args):

config=tf.ConfigProto(log_device_placement=False,allow_soft_placement=True)

sess = tf.Session(config=config)

K.set_session(sess) # K is keras backend

global_model = ConvNetA3C(84,84,4,num_actions=3)

threads = [threading.Thread(target=a3c_thread, args=(i, sess, global_model)) for i in range(NUM_THREADS)]

for t in threads:

t.start()

def a3c_thread(i, sess, global_model):

K.set_session(sess) # registering a session for each thread (don't know if it matters)

local_model = ConvNetA3C(84,84,4,num_actions=3)

sync = local_model.get_from(global_model) # I get the error here

#in the get_from function I do tf.assign(dest.params[i], src.params[i])

我从Keras

UserWarning: The default TensorFlow graph is not the graph associated with the TensorFlow session currently registered with Keras, and as such Keras was not able to automatically initialize a variable. You should consider registering the proper session with Keras via K.set_session(sess)

接着是tensorflow错误的tf.assign操作说操作必须在同一张图让用户警告。

ValueError: Tensor("conv1_W:0", shape=(8, 8, 4, 16), dtype=float32_ref, device=/device:CPU:0) must be from the same graph as Tensor("conv1_W:0", shape=(8, 8, 4, 16), dtype=float32_ref)

我不确定发生了什么问题。

感谢

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值