The basic approach to reading a dataset (TFRecord) with an iterator in TensorFlow

1. The three steps for reading a dataset
  1) define how the dataset is constructed;
  2) define an iterator over the dataset;
  3) obtain the data tensor from the iterator by calling its get_next() method.

For example:

import tensorflow as tf

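input_data = [1, 2, 3, 5, 8]
# Step 1: build the dataset from an in-memory list
dataset = tf.data.Dataset.from_tensor_slices(input_data)
# Step 2: define an iterator that traverses the dataset once
iterator = dataset.make_one_shot_iterator()
# Step 3: get_next() returns a tensor holding the next element
x = iterator.get_next()
y = x * x
with tf.Session() as sess:
    # Each sess.run(y) advances the iterator by one element
    for i in range(len(input_data)):
        print(sess.run(y))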

TextLineDataset can then be used to read data one line at a time, which is common in natural language processing tasks. An example follows:

input_files = ['D:/path/to/flowers/input_file1.txt', 'D:/path/to/flowers/input_file2.txt']
dataset = tf.data.TextLineDataset(input_files)
iterator = dataset.make_one_shot_iterator()
x = iterator.get_next()
with tf.Session() as sess:
    # Each run returns a string tensor that holds one line of a file.
    # This assumes the files contain at least 25 lines in total.
    for i in range(25):
        print(sess.run(x))

input_files is a list of strings, so it can name more than one text file; the dataset is then built from several files.
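
To make the multi-file case concrete, here is a minimal sketch that writes two small text files to a temporary directory and reads them back through a single TextLineDataset. The file names, their contents, and the helper `read_lines` are hypothetical and exist only for illustration. The sketch goes through `tf.compat.v1` so that it should also run on TensorFlow 2.x; on TensorFlow 1.x the calls match the ones used in this article.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style graph API, also available in TF 2.x


def read_lines(paths):
    """Return every line of the given text files, in the order listed."""
    lines = []
    with tf.Graph().as_default():
        dataset = tf.data.TextLineDataset(paths)
        next_line = tf1.data.make_one_shot_iterator(dataset).get_next()
        with tf1.Session() as sess:
            while True:
                try:
                    # Each run yields one line as a bytes object
                    lines.append(sess.run(next_line).decode("utf-8"))
                except tf.errors.OutOfRangeError:
                    break  # all files have been read
    return lines


# Hypothetical demo files, created only for this sketch
tmp_dir = tempfile.mkdtemp()
paths = []
for name, text in [("a.txt", "first\nsecond\n"), ("b.txt", "third\n")]:
    path = os.path.join(tmp_dir, name)
    with open(path, "w") as f:
        f.write(text)
    paths.append(path)

lines = read_lines(paths)
print(lines)
```

The lines come back in file order: everything from the first file, then everything from the second.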

The basic approach to reading data in TFRecord format:

import tensorflow as tf

# Define a function that decodes one record of a TFRecord file
def parser(record):
    features = tf.parse_single_example(
        record,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'pixels': tf.FixedLenFeature([], tf.int64),
            'label': tf.FixedLenFeature([], tf.int64)
        })
    # The image is stored as a raw byte string; decode it to uint8 values
    decoded_images = tf.decode_raw(features['image_raw'], tf.uint8)
    retyped_images = tf.cast(decoded_images, tf.float32)

    # Give each image a fixed shape of 784 values
    images = tf.reshape(retyped_images, [784])
    labels = tf.cast(features['label'], tf.int32)
    pixels = tf.cast(features['pixels'], tf.int32)
    return images, labels, pixels

# Build a dataset from TFRecord files; several files can be listed here.
input_files = ["D:/path/1output_test.tfrecords"]
dataset = tf.data.TFRecordDataset(input_files)

# map() applies the parser function to every record in the dataset
dataset = dataset.map(parser)

# Define an iterator to traverse the dataset
iterator = dataset.make_one_shot_iterator()

# Obtain the tensors that hold the next parsed record
image, label, pixels = iterator.get_next()

with tf.Session() as sess:
    # The while loop iterates over all the data without knowing the exact size of the dataset;
    # OutOfRangeError signals that the dataset is exhausted
    while True:
        try:
            x, y, z = sess.run([image, label, pixels])
            print(y, z)
        except tf.errors.OutOfRangeError:
            break
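
The reading code above needs a TFRecord file whose records carry the three features the parser expects. As a rough, self-contained sketch, the block below writes such a file and reads it back with the same pattern. The helper names `write_demo_file` and `read_all`, the all-zero images, and the labels are assumptions made up for this illustration, not the data behind this article. It uses the `tf.io` and `tf.compat.v1` spellings so that it should also run on TensorFlow 2.x.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style graph API, also available in TF 2.x


def write_demo_file(path, labels):
    """Write one all-zero 784-byte image per label (hypothetical data)."""
    with tf.io.TFRecordWriter(path) as writer:
        for label in labels:
            feature = {
                "image_raw": tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[bytes(784)])),
                "pixels": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[784])),
                "label": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[label])),
            }
            example = tf.train.Example(
                features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())


def parser(record):
    """Decode one serialized example into (image, label, pixels)."""
    features = tf.io.parse_single_example(
        record,
        features={
            "image_raw": tf.io.FixedLenFeature([], tf.string),
            "pixels": tf.io.FixedLenFeature([], tf.int64),
            "label": tf.io.FixedLenFeature([], tf.int64),
        })
    image = tf.cast(tf.io.decode_raw(features["image_raw"], tf.uint8),
                    tf.float32)
    return (tf.reshape(image, [784]),
            tf.cast(features["label"], tf.int32),
            tf.cast(features["pixels"], tf.int32))


def read_all(paths):
    """Read every record; return (image shape, label, pixels) per record."""
    rows = []
    with tf.Graph().as_default():
        dataset = tf.data.TFRecordDataset(paths).map(parser)
        iterator = tf1.data.make_one_shot_iterator(dataset)
        image, label, pixels = iterator.get_next()
        with tf1.Session() as sess:
            while True:
                try:
                    x, y, z = sess.run([image, label, pixels])
                    rows.append((x.shape, int(y), int(z)))
                except tf.errors.OutOfRangeError:
                    break  # the dataset is exhausted
    return rows


demo_path = os.path.join(tempfile.mkdtemp(), "demo.tfrecords")
write_demo_file(demo_path, labels=[3, 1, 4])
rows = read_all([demo_path])
print(rows)
```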


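An alternative is to pass the file names through a placeholder and use an initializable iterator:

input_files = tf.placeholder(tf.string)
dataset = tf.data.TFRecordDataset(input_files)
dataset = dataset.map(parser)
# An initializable iterator must be initialized (run iterator.initializer,
# feeding input_files) before get_next() can return data.
iterator = dataset.make_initializable_iterator()
image, label, pixels = iterator.get_next()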
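
The snippet above stops before the iterator is initialized. The sketch below shows one way the placeholder and the initializable iterator might be used together: the same iterator is initialized twice, each time with a different list of files fed through the placeholder. The helpers `write_raw_records` and `read_each`, the file names, and the byte payloads are hypothetical. To keep the example short, the records are raw byte strings and the `parser` mapping step is left out. It uses `tf.compat.v1` so that it should also run on TensorFlow 2.x.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style graph API, also available in TF 2.x


def write_raw_records(path, payloads):
    """Write each bytes payload as one record (hypothetical demo data)."""
    with tf.io.TFRecordWriter(path) as writer:
        for payload in payloads:
            writer.write(payload)


def read_each(file_lists):
    """Initialize one iterator per file list; return the records read."""
    results = []
    with tf.Graph().as_default():
        input_files = tf1.placeholder(tf.string, shape=[None])
        dataset = tf.data.TFRecordDataset(input_files)
        iterator = tf1.data.make_initializable_iterator(dataset)
        record = iterator.get_next()
        with tf1.Session() as sess:
            for files in file_lists:
                # Feeding the placeholder here decides which files are read
                sess.run(iterator.initializer,
                         feed_dict={input_files: files})
                records = []
                while True:
                    try:
                        records.append(sess.run(record))
                    except tf.errors.OutOfRangeError:
                        break  # this file list is exhausted
                results.append(records)
    return results


tmp_dir = tempfile.mkdtemp()
train_path = os.path.join(tmp_dir, "train.tfrecords")
test_path = os.path.join(tmp_dir, "test.tfrecords")
write_raw_records(train_path, [b"a", b"b", b"c"])
write_raw_records(test_path, [b"x"])

train_records, test_records = read_each([[train_path], [test_path]])
print(len(train_records), len(test_records))
```

Because the file names are supplied at initialization time, the same graph can be reused for different inputs, for example a training set and a test set.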

 

 


