Reading TFRecord image data with the new TensorFlow Dataset API

I am unable to read image data in TFRecord format using the "new" (TensorFlow v1.4) Dataset API. I believe the problem is that I am somehow consuming the whole dataset instead of a single batch when trying to read. I have a working example of doing this using the batch/file-queue API here: https://github.com/gnperdue/TFExperiments/tree/master/conv (in the example I am running a classifier, but the code that reads the TFRecord images is in the DataReaders.py classes).
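For context on the on-disk format: a TFRecord file is just a flat sequence of length-prefixed records, so a reader can in principle pull them out one at a time rather than consuming the whole file. A minimal pure-Python sketch of that framing, with the CRC32C checksum fields stubbed out as zeros (real files populate them with masked CRC32C values):

```python
import io
import struct

# Each record in a TFRecord file is framed as:
#   uint64 length | uint32 masked-crc of length | data | uint32 masked-crc of data

def write_record(stream, data):
    # Write one framed record (checksums stubbed with zeros for brevity).
    stream.write(struct.pack('<Q', len(data)))   # record length
    stream.write(struct.pack('<I', 0))           # length CRC (stubbed)
    stream.write(data)                           # payload bytes
    stream.write(struct.pack('<I', 0))           # data CRC (stubbed)

def read_records(stream):
    # Yield each record's payload, skipping the checksum fields.
    while True:
        header = stream.read(8)
        if not header:
            return
        (length,) = struct.unpack('<Q', header)
        stream.read(4)                 # skip length CRC
        data = stream.read(length)     # payload
        stream.read(4)                 # skip data CRC
        yield data

buf = io.BytesIO()
for payload in (b'record-0', b'record-1', b'record-2'):
    write_record(buf, payload)
buf.seek(0)
print(list(read_records(buf)))  # [b'record-0', b'record-1', b'record-2']
```

So each serialized `tf.train.Example` sits in its own frame, and `TFRecordDataset` emits them one string tensor at a time.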

I believe the problem functions are:

```python
def parse_mnist_tfrec(tfrecord, features_shape):
    tfrecord_features = tf.parse_single_example(
        tfrecord,
        features={
            'features': tf.FixedLenFeature([], tf.string),
            'targets': tf.FixedLenFeature([], tf.string)
        }
    )
    features = tf.decode_raw(tfrecord_features['features'], tf.uint8)
    features = tf.reshape(features, features_shape)
    features = tf.cast(features, tf.float32)
    targets = tf.decode_raw(tfrecord_features['targets'], tf.uint8)
    targets = tf.one_hot(indices=targets, depth=10, on_value=1, off_value=0)
    targets = tf.cast(targets, tf.float32)
    return features, targets
```
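The `tf.reshape` call above can only succeed if the number of decoded bytes equals the product of `features_shape`; that invariant is exactly what the error below complains about. A quick sanity check of it in plain Python (`can_reshape` is a hypothetical helper for illustration, not part of TensorFlow):

```python
from functools import reduce
from operator import mul

def can_reshape(num_values, shape):
    # tf.reshape-style check: the element count must equal the product of
    # the target dimensions (the -1 wildcard is not handled in this sketch).
    return num_values == reduce(mul, shape, 1)

# A single 28x28x1 MNIST image decodes to 784 bytes:
print(can_reshape(784, [28, 28, 1]))     # True
# 50000 values (as in the error message) cannot fill [28, 28, 1]:
print(can_reshape(50000, [28, 28, 1]))   # False
```

A count like 50000 reaching the reshape suggests the tensor holds many examples' worth of bytes, not one record's.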

```python
class MNISTDataReaderDset:
    def __init__(self, data_reader_dict):
        pass  # doesn't matter here

    def batch_generator(self, num_epochs=1):
        def parse_fn(tfrecord):
            return parse_mnist_tfrec(
                tfrecord, self.name, self.features_shape
            )
        dataset = tf.data.TFRecordDataset(
            self.filenames_list, compression_type=self.compression_type
        )
        dataset = dataset.map(parse_fn)
        dataset = dataset.repeat(num_epochs)
        dataset = dataset.batch(self.batch_size)
        iterator = dataset.make_one_shot_iterator()
        batch_features, batch_labels = iterator.get_next()
        return batch_features, batch_labels
```

Then, in use:

```python
batch_features, batch_labels = \
    data_reader.batch_generator(num_epochs=1)

sess.run(tf.local_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
try:
    # look at 3 batches only
    for _ in range(3):
        labels, feats = sess.run([
            batch_labels, batch_features
        ])
```

This produces an error like:

```
Input to reshape is a tensor with 50000 values, but the requested shape has 1
[[Node: Reshape_1 = Reshape[T=DT_UINT8, Tshape=DT_INT32](DecodeRaw_1, Reshape_1/shape)]]
[[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,28,28,1], [?,10]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]
```

Does anyone have any ideas?

I have a gist with the full code from the reader example, along with a link to the TFRecord files (our old friend MNIST, in TFRecord form):

Thanks!

Edit - I also tried `flat_map`, e.g.:

```python
def batch_generator(self, num_epochs=1):
    """
    TODO - we can use placeholders for the list of file names and
    init with a feed_dict when we call `sess.run` - give this a
    try with one list for training and one for validation
    """
    def parse_fn(tfrecord):
        return parse_mnist_tfrec(
            tfrecord, self.name, self.features_shape
        )
    dataset = tf.data.Dataset.from_tensor_slices(self.filenames_list)
    dataset = dataset.flat_map(
        lambda filename: (
            tf.data.TFRecordDataset(
                filename, compression_type=self.compression_type
            ).map(parse_fn).batch(self.batch_size)
        )
    )
    dataset = dataset.repeat(num_epochs)
    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels
```
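For reference, `flat_map` maps each element to a whole sub-sequence and concatenates the results in order, so batching inside the lambda (as above) yields batches that never cross file boundaries. The semantics in plain Python (again a toy model, not the `tf.data` API; the file names and record contents are made up for illustration):

```python
def flat_map(items, fn):
    # Like Dataset.flat_map: fn maps each element to an iterable,
    # and the resulting iterables are concatenated in order.
    for item in items:
        for sub in fn(item):
            yield sub

# Hypothetical files mapping names to their parsed records.
files = {'a.tfrecord': [1, 2, 3], 'b.tfrecord': [4, 5]}

def per_file_batches(filename, batch_size=2):
    # Batch records from one file only, like the lambda in the question.
    recs = files[filename]
    return [recs[i:i + batch_size] for i in range(0, len(recs), batch_size)]

print(list(flat_map(files, per_file_batches)))  # [[1, 2], [3], [4, 5]]
```

Note the short batches `[3]` and `[4, 5]` at each file boundary: per-file batching restarts with every file.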

I also tried using just a single file instead of a list (in my first approach above). No matter what, it seems TF always wants to consume the entire file into the `TFRecordDataset` and won't operate on single records.
