利用数据集读取数据有三个基本步骤:
- 定义数据及的构造方法,如tf.data.TFRecordDataset(input_files)
- 定义遍历器,如one_shot_iterator,initializable_iterator
- 使用get_next()获取tensor
例:
import tensorflow as tf
def parser(record):
features = tf.parse_single_example(
record,
features={
'feat1':tf.FixedLenFeature([],tf.int64),
'feat2':tf.FixedLenFeature([],tf.int64)
}
)
return features['feat1'],features['feat2']
#数据集可以是一个tensor,或者文本文件
#若是tensor,则使用tf.data.from_tensor_slices(input_data)
#若是文本文件,则使用tf.data.TextLineDataset(input_files)
input_files = ['file1','file2']
dataset = tf.data.TFRecordDataset(input_files)
#由于tfrecords读取出来的是二进制数据,需要对每个数据进行解析,得到想要的格式
#这里使用映射函数对每个数据进行解析
dataset = dataset.map(parser)
#通过一个迭代器获取数据
iterator = dataset.make_one_shot_iterator()
feat1,feat2 = iterator.get_next()
with tf.Session() as sess:
for i