一、创建 TFRecord 文件
创建时，需要按照上篇文章介绍的三种数据格式（Int64List、FloatList、BytesList）来构造每个特征。
import tensorflow as tf
import numpy as np
# Write 100 synthetic examples to a TFRecord file.
# Each tf.train.Example carries two features:
#   'lable' -- an int64 label (always 1 here). The key spelling must match
#              the key used when the file is parsed back.
#   'img'   -- the raw bytes of a (1, 3) float64 array.
TFFilePath = 'test.tfrecords'
# The context manager closes (and flushes) the writer even if an exception
# is raised mid-loop, so the file is never left open and half-written.
with tf.python_io.TFRecordWriter(TFFilePath) as writer:
    for _ in range(100):
        # np.random.random returns float64 values in [0, 1); used as test data.
        randomArray = np.random.random((1, 3))
        # tobytes() emits only the raw buffer -- no dtype or shape is stored,
        # so the reader must decode with the same dtype (float64).
        array_raw = randomArray.tobytes()
        example = tf.train.Example(features=tf.train.Features(feature={
            'lable': tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
            'img': tf.train.Feature(bytes_list=tf.train.BytesList(value=[array_raw])),
        }))
        writer.write(example.SerializeToString())
二、读取 TFRecord 文件
读取时需要保持与创建时相同的结构：特征的键名和类型必须与写入时一致，并且用 decode_raw 解码时指定的数据类型也必须与写入时数组的数据类型一致，否则得到的将是无意义的字节。
import tensorflow as tf
# Read the examples back through a TF1 queue-based input pipeline.
# Queue of input filenames; num_epochs=None cycles through the file forever.
files = tf.train.string_input_producer(['test.tfrecords'], num_epochs=None)
reader = tf.TFRecordReader()  # reads serialized records from the filename queue
_, serialized_example = reader.read(files)
# Parse a single serialized tf.train.Example. Keys and types must match what
# the writer stored: 'lable' (int64 scalar) and 'img' (bytes).
features = tf.parse_single_example(
    serialized_example,
    features={
        'lable': tf.FixedLenFeature([], tf.int64),
        'img': tf.FixedLenFeature([], tf.string),
    }
)
# The 'img' bytes were produced by ndarray.tobytes() on a float64 array, so
# they must be decoded as float64. Decoding as uint8 would instead yield each
# float's 8 raw bytes as separate meaningless integers. The result is a flat
# vector (reshape it if the original (1, 3) shape is needed).
image = tf.decode_raw(features['img'], tf.float64)
lable = tf.cast(features['lable'], tf.int32)

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    # Start the background threads that feed the input queues.
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for _ in range(10):
            print(sess.run([image, lable]))
    finally:
        # Stop and join the queue-runner threads before the session closes,
        # so they are not left running against a closed session.
        coord.request_stop()
        coord.join(threads)
其中，tf.train.string_input_producer 返回一个由文件名（字符串）组成的队列，