我无法使用“new”(TensorFlow v1.4)数据集API读取TFRecord格式的图像数据.我认为问题是我在尝试阅读时以某种方式消耗整个数据集而不是单个批次.我有一个使用批处理/文件队列API执行此操作的示例:
https://github.com/gnperdue/TFExperiments/tree/master/conv(在示例中,我正在运行分类器,但读取TFRecord映像的代码位于DataReaders.py类中).
我相信问题的功能是:
def parse_mnist_tfrec(tfrecord, features_shape):
tfrecord_features = tf.parse_single_example(
tfrecord,
features={
'features': tf.FixedLenFeature([], tf.string),
'targets': tf.FixedLenFeature([], tf.string)
}
)
features = tf.decode_raw(tfrecord_features['features'], tf.uint8)
features = tf.reshape(features, features_shape)
features = tf.cast(features, tf.float32)
targets = tf.decode_raw(tfrecord_features['targets'], tf.uint8)
targets = tf.one_hot(indices=targets, depth=10, on_value=1, off_value=0)
targets = tf.cast(targets, tf.float32)
return features, targets
class MNISTDataReaderDset:
def __init__(self, data_reader_dict):
# doesn't matter here
def batch_generator(self, num_epochs=1):
def parse_fn(tfrecord):
return parse_mnist_tfrec(
tfrecord, self.name, self.features_shape
)
dataset = tf.data.TFRecordDataset(
self.filenames_list, compression_type=self.compression_type
)
dataset = dataset.map(parse_fn)
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(self.batch_size)
iterator = dataset.make_one_shot_iterator()
batch_features, batch_labels = iterator.get_next()
return batch_features, batch_labels
然后,在使用中:
batch_features, batch_labels = \
data_reader.batch_generator(num_epochs=1)
sess.run(tf.local_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
try:
# look at 3 batches only
for _ in range(3):
labels, feats = sess.run([
batch_labels, batch_features
])
这会产生如下错误:
[[Node: Reshape_1 = Reshape[T=DT_UINT8, Tshape=DT_INT32](DecodeRaw_1, Reshape_1/shape)]]
Input to reshape is a tensor with 50000 values, but the requested shape has 1
[[Node: Reshape_1 = Reshape[T=DT_UINT8, Tshape=DT_INT32](DecodeRaw_1, Reshape_1/shape)]]
[[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,28,28,1], [?,10]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]
有没有人有任何想法?
我有一个关于读者示例中的完整代码的要点,以及TFRecord文件的链接(我们的老朋友,MNIST,TFRecord形式):
谢谢!
编辑 – 我也尝试过flat_map,例如:
def batch_generator(self, num_epochs=1):
"""
TODO - we can use placeholders for the list of file names and
init with a feed_dict when we call `sess.run` - give this a
try with one list for training and one for validation
"""
def parse_fn(tfrecord):
return parse_mnist_tfrec(
tfrecord, self.name, self.features_shape
)
dataset = tf.data.Dataset.from_tensor_slices(self.filenames_list)
dataset = dataset.flat_map(
lambda filename: (
tf.data.TFRecordDataset(
filename, compression_type=self.compression_type
).map(parse_fn).batch(self.batch_size)
)
)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()
batch_features, batch_labels = iterator.get_next()
return batch_features, batch_labels
我也试过只使用一个文件而不是列表(在我上面的第一种方法中).无论如何,似乎TF总是希望将整个文件包含在TFRecordDataset中,并且不会对单个记录进行操作.
最佳答案 好的,我想出来了 – 上面的代码很好.问题是我创建TFRecords的脚本.基本上,我有一个像这样的块
def write_tfrecord(reader, start_idx, stop_idx, tfrecord_file):
writer = tf.python_io.TFRecordWriter(tfrecord_file)
tfeat, ttarg = get_binary_data(reader, start_idx, stop_idx)
example = tf.train.Example(
features=tf.train.Features(
feature={
'features': tf.train.Feature(
bytes_list=tf.train.BytesList(value=[tfeat])
),
'targets': tf.train.Feature(
bytes_list=tf.train.BytesList(value=[ttarg])
)
}
)
)
writer.write(example.SerializeToString())
writer.close()
而我需要这样一个块:
def write_tfrecord(reader, start_idx, stop_idx, tfrecord_file):
writer = tf.python_io.TFRecordWriter(tfrecord_file)
for idx in range(start_idx, stop_idx):
tfeat, ttarg = get_binary_data(reader, idx)
example = tf.train.Example(
features=tf.train.Features(
feature={
'features': tf.train.Feature(
bytes_list=tf.train.BytesList(value=[tfeat])
),
'targets': tf.train.Feature(
bytes_list=tf.train.BytesList(value=[ttarg])
)
}
)
)
writer.write(example.SerializeToString())
writer.close()
也就是说 – 当我需要在数据中为每个例子制作一个时,我基本上将整个数据块写成一个巨大的TFRecord.
事实证明,如果你在旧文件和批处理队列API中都这样做,一切正常 – 像tf.train.batch这样的函数是自动神奇的“智能”足以分割大块或连接大量单个 – 示例根据您提供的内容记录到批处理中.当我修改了制作TFRecords文件的代码时,我不需要更改旧文件和批处理队列代码中的任何内容,它仍然可以正常使用TFRecords文件.但是,数据集API对此差异很敏感.这就是为什么在我上面的代码中它似乎总是消耗整个文件 – 因为整个文件确实是一个很大的TFRecord.