Tensorflow sess.run 和 string_input_producer 挂起的解决方案

网上常见的解决方式是

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

但是要注意的是 queue 的启动要放在 queue 的定义之后
如:

sess = tf.Session(config=config)
sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

coord = tf.train.Coordinator()
# queue 的启动
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

# queue 的定义
tmp = tf.train.string_input_producer(['./test.tfrecords'], shuffle=True)
reader = tf.TFRecordReader(options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB))
_, serialized_example = reader.read(tmp)
sess.run(serialized_example)

上面的代码会卡在sess.run()
如果修改成:

sess = tf.Session(config=config)
sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

coord = tf.train.Coordinator()

# queue 的定义
tmp = tf.train.string_input_producer(['./test.tfrecords'], shuffle=True)
reader = tf.TFRecordReader(options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB))
_, serialized_example = reader.read(tmp)

# queue 的启动
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

sess.run(serialized_example)

就不会卡住了

(其实如果仔细读 log, 会发现 tensorflow 早就 warn 过了 WARNING:tensorflow:tf.train.start_queue_runners() was called when no queue runners were defined. You can safely remove the call to this deprecated function.)

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
`tf.train.slice_input_producer()` 函数在 TensorFlow 2.0 版本中已经被弃用,取而代之的是 `tf.compat.v1.train.slice_input_producer()` 函数。如果在 TensorFlow 2.0 及之后的版本中使用 `tf.train.slice_input_producer()` 函数会收到警告信息。 在 TensorFlow 2.0 及之后的版本中,`tf.compat.v1.train.slice_input_producer()` 函数的使用方法与 `tf.train.slice_input_producer()` 函数相同,可以用于生成输入数据队列。例如: ```python import tensorflow as tf # 生成输入数据队列 data = [1, 2, 3, 4, 5] input_queue = tf.compat.v1.train.slice_input_producer([data], num_epochs=1, shuffle=True) # 读取队列中的数据 x = input_queue[0] # 创建会话,读取队列中的数据并打印 with tf.compat.v1.Session() as sess: # 初始化变量 sess.run(tf.compat.v1.global_variables_initializer()) sess.run(tf.compat.v1.local_variables_initializer()) # 启动队列 coord = tf.train.Coordinator() threads = tf.compat.v1.train.start_queue_runners(coord=coord) # 读取数据并打印 try: while not coord.should_stop(): print(sess.run(x)) except tf.errors.OutOfRangeError: print('Done!') finally: coord.request_stop() coord.join(threads) ``` 需要注意的是,在 TensorFlow 2.0 及之后的版本中,`tf.compat.v1.train.slice_input_producer()` 函数返回的是一个元组,需要通过索引访问元素。同时,需要使用 `tf.compat.v1.Session()` 和 `tf.compat.v1.train.start_queue_runners()` 函数来启动队列。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值