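I. Background
I've recently been studying image data preprocessing and ran into a lot I didn't understand, so I looked things up on Baidu. Readers, please work through the following in order.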
II. Recommended blog posts
1. The difference between Session.run() and Tensor.eval():
https://blog.csdn.net/chengshuhao1991/article/details/78554743
2. Reading pipelines for four types of data, with API explanations and code:
https://blog.csdn.net/chengshuhao1991/article/details/78655292
3. Three ways to read data:
https://blog.csdn.net/chengshuhao1991/article/details/78644966
4. Storing and reading TFRecords files, with explanations and code:
https://blog.csdn.net/chengshuhao1991/article/details/78656724
5. You may also need some background on queues:
https://blog.csdn.net/weixin_38715680/article/details/89889573
III. Complete example
First, use the following script to generate some simulated data and store it in TFRecord files:
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# author:Dr.Shang
# written against the TensorFlow 1.x API (tf.python_io, tf.train queue runners)
import tensorflow as tf


# helper: wrap a single integer in a tf.train.Feature for TFRecord
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


# simulate a large dataset being split across several files (shards)
num_shards = 2  # total number of files
instance_per_shard = 2  # number of examples per file
for i in range(num_shards):
    # zero-padded file names, e.g. data.tfrecords-00000-of-00002
    filename = ('data.tfrecords-%.5d-of-%.5d' % (i, num_shards))
    writer = tf.python_io.TFRecordWriter(filename)
    # wrap the data in an Example structure and write it to the TFRecord file
    for j in range(instance_per_shard):
        # each Example only records which file it belongs to (i)
        # and its index within that file (j)
        example = tf.train.Example(features=tf.train.Features(feature={
            'i': _int64_feature(i),
            'j': _int64_feature(j)
        }))
        # Note the structure: Example -> Features -> Feature, one nested inside another.
        # (1) tf.train.Example(features=None)
        #     the record that is written to the TFRecord file
        #     features: a tf.train.Features instance
        #     returns: an Example protocol buffer
        # (2) tf.train.Features(feature=None)
        #     builds the key-value pairs describing one sample
        #     feature: dict whose keys are the feature names and whose
        #              values are tf.train.Feature instances
        #     returns: a Features message
        # (3) tf.train.Feature(**options)
        #     options is one of the following three list types:
        #     bytes_list=tf.train.BytesList(value=[Bytes])
        #     int64_list=tf.train.Int64List(value=[Value])
        #     float_list=tf.train.FloatList(value=[Value])
        writer.write(example.SerializeToString())
    writer.close()
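After running the data-generation script, the following two files are created in the working directory:
data.tfrecords-00000-of-00002
data.tfrecords-00001-of-00002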
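Each of these files is a plain sequence of records. As I understand the TFRecord format, every record is stored as an 8-byte little-endian length, a 4-byte masked CRC32C of that length, the payload bytes (here, a serialized Example), and a 4-byte masked CRC32C of the payload. The stdlib-only sketch below writes and reads that framing for arbitrary byte payloads; `crc32c`, `masked_crc`, `write_records` and `read_records` are hypothetical helpers for illustration, and in practice `tf.python_io.TFRecordWriter` and `tf.TFRecordReader` handle all of this for you.

```python
import struct


def _crc32c_table():
    # CRC-32C (Castagnoli), reflected polynomial 0x82F63B78
    table = []
    for n in range(256):
        c = n
        for _ in range(8):
            c = (c >> 1) ^ 0x82F63B78 if c & 1 else c >> 1
        table.append(c)
    return table


_TABLE = _crc32c_table()


def crc32c(data):
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data):
    # rotate the CRC right by 15 bits, then add a constant (mod 2**32)
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF


def write_records(path, payloads):
    with open(path, "wb") as f:
        for data in payloads:
            header = struct.pack("<Q", len(data))  # uint64 length
            f.write(header)
            f.write(struct.pack("<I", masked_crc(header)))
            f.write(data)
            f.write(struct.pack("<I", masked_crc(data)))


def read_records(path):
    with open(path, "rb") as f:
        while True:
            header = f.read(8)
            if not header:
                return  # clean end of file
            if len(header) != 8:
                raise ValueError("truncated length header")
            (length,) = struct.unpack("<Q", header)
            (length_crc,) = struct.unpack("<I", f.read(4))
            if length_crc != masked_crc(header):
                raise ValueError("corrupted length field")
            data = f.read(length)
            (data_crc,) = struct.unpack("<I", f.read(4))
            if data_crc != masked_crc(data):
                raise ValueError("corrupted payload")
            yield data
```

If the format assumption above holds, pointing `read_records` at one of the generated `data.tfrecords-*` files should yield the raw serialized Example bytes; decoding those bytes still requires protobuf or TensorFlow.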
Next, a complete example: reading the data back from the TFRecord files using threads and queues:
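#!/usr/bin/env python
# -*- coding:utf-8 -*-
# author:Dr.Shang
# written against the TensorFlow 1.x API (tf.train queue runners)
import tensorflow as tf

# use tf.train.match_filenames_once to get the list of matching files
files = tf.train.match_filenames_once("data.tfrecords-*")
# create the filename input queue with tf.train.string_input_producer(files);
# by default it shuffles the file order (shuffle=True) and cycles indefinitely
filename_queue = tf.train.string_input_producer(files)
# read and parse one example
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example, features={
    'i': tf.FixedLenFeature([], tf.int64),
    'j': tf.FixedLenFeature([], tf.int64),
})

with tf.Session() as sess:
    # match_filenames_once keeps its result in a local variable,
    # so local variables must be initialized first
    tf.local_variables_initializer().run()
    print(sess.run(files))  # print the file list
    # start the queue-runner threads, coordinated by a Coordinator
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # fetch data: 6 reads but only 4 records, so the files are revisited
    for i in range(6):
        print(sess.run([features['i'], features['j']]))
    # ask the threads to stop and wait for them to finish
    coord.request_stop()
    coord.join(threads)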
output:
[b'.\\data.tfrecords-00000-of-00002' b'.\\data.tfrecords-00001-of-00002']
[1, 0]
[1, 1]
[0, 0]
[0, 1]
[0, 0]
[0, 1]
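Why does the output start with file 1 and then show file 0 twice? `tf.train.string_input_producer` shuffles the file list by default (`shuffle=True`) and, with `num_epochs=None`, keeps cycling through it, so 6 reads over 4 records wrap around into a second pass. The plain-Python sketch below imitates that pattern with `threading` and `queue`: a producer thread keeps feeding reshuffled filenames into a bounded queue, the main thread reads records file by file, and a `threading.Event` plays the role of `coord.request_stop()`. It is only an analogy under those assumptions; the in-memory `RECORDS` dict and the function names are hypothetical stand-ins for the two shard files, not TensorFlow internals.

```python
import queue
import random
import threading

# Hypothetical in-memory stand-in for the two shard files: name -> [(i, j), ...]
RECORDS = {
    "data.tfrecords-00000-of-00002": [(0, 0), (0, 1)],
    "data.tfrecords-00001-of-00002": [(1, 0), (1, 1)],
}


def _put(q, item, stop):
    """Enqueue item, giving up if stop is set while the queue is full."""
    while not stop.is_set():
        try:
            q.put(item, timeout=0.05)
            return True
        except queue.Full:
            pass
    return False


def filename_producer(files, q, stop, rng):
    """Feed filenames forever, reshuffling every epoch (like shuffle=True, num_epochs=None)."""
    while not stop.is_set():
        epoch = list(files)
        rng.shuffle(epoch)
        for name in epoch:
            if not _put(q, name, stop):
                return


def read_n_records(n, seed=None):
    """Read n records one at a time, dequeuing a new filename when a file runs out."""
    q = queue.Queue(maxsize=2)  # bounded filename queue
    stop = threading.Event()  # plays the role of tf.train.Coordinator
    producer = threading.Thread(
        target=filename_producer,
        args=(sorted(RECORDS), q, stop, random.Random(seed)),
        daemon=True,
    )
    producer.start()
    out, pending = [], []
    while len(out) < n:
        if not pending:  # current file exhausted: take the next filename
            pending = list(RECORDS[q.get()])
        out.append(pending.pop(0))
    stop.set()  # like coord.request_stop()
    producer.join()  # like coord.join(threads)
    return out


if __name__ == "__main__":
    for record in read_n_records(6):
        print(list(record))
```

The first four records always cover both files completely, and reads five and six come from whichever file the second shuffled pass puts first, which matches the shape of the TensorFlow output above.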