1. The three steps for reading a dataset:
1) Construct the dataset from its data source;
2) Create an iterator over the dataset;
3) Obtain the data tensors from the iterator with its get_next() method.
For example:
import tensorflow as tf

# Build a dataset whose elements are the entries of a Python list.
input_data = [1, 2, 3, 5, 8]
dataset = tf.data.Dataset.from_tensor_slices(input_data)
# A one-shot iterator is used here without any explicit initialisation step.
iterator = dataset.make_one_shot_iterator()
x = iterator.get_next()
y = x * x
with tf.Session() as sess:
    # Read until the iterator is exhausted instead of counting with
    # len(input_data), so the loop stays correct if the dataset is later
    # filtered or batched and its element count no longer matches the list.
    while True:
        try:
            squared = sess.run(y)
        except tf.errors.OutOfRangeError:
            break
        print(squared)
# Several text files can be listed; the lines of all of them feed one dataset.
# NOTE(review): file names assumed to be input_file1/input_file2 — adjust to
# the actual input files.
input_files = ['D:/path/to/flowers/input_file1.txt', 'D:/path/to/flowers/input_file2.txt']
dataset = tf.data.TextLineDataset(input_files)
iterator = dataset.make_one_shot_iterator()
# Fetching `line` returns a string tensor value: one line of a file.
line = iterator.get_next()
with tf.Session() as sess:
    # Stop when all lines have been read rather than assuming a fixed count;
    # a hard-coded count raises OutOfRangeError if the files are shorter.
    while True:
        try:
            text = sess.run(line)
        except tf.errors.OutOfRangeError:
            break
        print(text)
`input_files` can list more than one text file: it is a list of file-name strings, so a single dataset can be built from several files.
import tensorflow as tf
def parser(record, image_size=784):
    """Decode one serialized example read from a TFRecord file.

    Args:
        record: string tensor holding one serialized example, as produced
            by tf.data.TFRecordDataset.
        image_size: number of values in the flattened image. Defaults to 784
            (presumably 28x28 MNIST images — confirm against the writer).
            It must equal the number of bytes stored in 'image_raw',
            otherwise the reshape fails at run time.

    Returns:
        (images, labels, pixels): the image as a float32 tensor of shape
        [image_size], and the 'label' and 'pixels' features cast to int32.
    """
    features = tf.parse_single_example(
        record,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'pixels': tf.FixedLenFeature([], tf.int64),
            'label': tf.FixedLenFeature([], tf.int64),
        })
    # 'image_raw' is a byte string; decode it as uint8 values, then cast to
    # float32.
    decoded_images = tf.decode_raw(features['image_raw'], tf.uint8)
    retyped_images = tf.cast(decoded_images, tf.float32)
    images = tf.reshape(retyped_images, [image_size])
    labels = tf.cast(features['label'], tf.int32)
    pixels = tf.cast(features['pixels'], tf.int32)
    return images, labels, pixels
# Build a dataset from TFRecord files; several files may be listed here.
input_files = ["D:/path/1output_test.tfrecords"]
dataset = tf.data.TFRecordDataset(input_files)
# map() decodes every record in the dataset with the parser function.
dataset = dataset.map(parser)
# Define an iterator to walk the dataset.
iterator = dataset.make_one_shot_iterator()
# Tensors for the next decoded record.
image, label, pixels = iterator.get_next()
with tf.Session() as sess:
    # Loop until OutOfRangeError, so the dataset size need not be known.
    while True:
        try:
            image_value, label_value, pixels_value = sess.run([image, label, pixels])
        except tf.errors.OutOfRangeError:
            break
        # image_value holds the decoded image; only the scalars are printed.
        print(label_value, pixels_value)
'''
input_files = tf.placeholder(tf.string)
dataset = tf.data. TFRecordDataset(input_files)
dataset = dataset.map(parser)
iterator = dataset.make_initializable_iterator()
image, label, pixels = iterator.get_next()