#coding=utf-8
import json
import numpy as np
import tensorflow as tf
from IPython import embed
# Input pipeline: read test.tfrecord exactly once, in file order.
# Because num_epochs is set, tf.local_variables_initializer() has to be run
# before the first read (otherwise OutOfRangeError is raised right away).
file_queue = tf.train.string_input_producer(
    ['test.tfrecord'], num_epochs=1, shuffle=False)
reader = tf.TFRecordReader()
# read() yields a (key, value) pair; only the serialized record is kept.
_, serialized = reader.read(file_queue)
# Context: a scalar "id" string and a variable-length int64 "labels" list
# (VarLenFeature, so it comes back sparse). Sequence: one "rgb" string per step.
contexts, features = tf.parse_single_sequence_example(
    serialized,
    sequence_features={"rgb": tf.FixedLenSequenceFeature([], tf.string)},
    context_features={
        "id": tf.FixedLenFeature([], tf.string),
        "labels": tf.VarLenFeature(tf.int64),
    },
)
# Session config: let TensorFlow grow GPU memory on demand rather than
# reserving it all when the session starts.
conf = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
#shit = tf.train.shuffle_batch(
# [contexts, features],
# batch_size=1,
# capacity=10 + 3,
# min_after_dequeue=10)
# Load the reference collection that record ids are checked against
# (presumably a JSON object keyed by id -- TODO confirm). JSON is UTF-8 by
# spec, so decode explicitly instead of relying on the platform locale.
with open('completion.json', 'r', encoding='utf-8') as f:
    comple_dict = json.load(f)
# Running tallies of ids found / not found in comple_dict.
miss = 0
hit = 0
# Consume the single epoch of records. For each one, count whether its id
# (minus the last 4 bytes) is in comple_dict, and print the running totals.
with tf.Session(config=conf) as sess:
    # num_epochs=1 keeps its epoch counter in a local variable; without this
    # initializer the first read raises OutOfRangeError immediately.
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        # One loop only, so a stop requested through the coordinator
        # (e.g. by a failing queue-runner thread) is actually observed.
        while not coord.should_stop():
            ctx_v, feats_v = sess.run([contexts, features])
            # Strip the last 4 bytes of the id (presumably a file extension
            # such as ".mp4" -- TODO confirm) before the lookup.
            if str(ctx_v['id'][:-4], encoding='utf-8') in comple_dict:
                hit += 1
            else:
                miss += 1
            print('hit: %d, miss: %d' % (hit, miss))
    except tf.errors.OutOfRangeError as e:
        # With num_epochs=1 this signals that the input has been exhausted.
        print(e)
    finally:
        # Always stop and join the queue-runner threads before the session
        # is closed, however the loop ended.
        coord.request_stop()
        coord.join(threads)
    print('finish')
    # Interactive inspection shell; the session is still open here.
    embed()
# 1) If num_epochs=1 is used, tf.local_variables_initializer() must be run; otherwise an OutOfRangeError is raised right at the start.