import os

# TF_CPP_MIN_LOG_LEVEL is read by TensorFlow's C++ runtime, so it has to be set
# BEFORE tensorflow is imported to take effect. '2' filters out INFO and
# WARNING messages (only errors are printed).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf  # noqa: E402  (deliberately imported after the env var)
class tfrecords_file(object):
    """Convert CIFAR-10 binary files to TFRecords and read them back.

    Every reader method builds a TensorFlow 1.x queue-based input pipeline
    (``tf.train.string_input_producer`` + a reader + ``tf.train.batch``). The
    returned tensors therefore only produce data after
    ``tf.train.start_queue_runners`` has been called on the evaluating session.
    """

    def __init__(self, file_list):
        """
        :param file_list: paths of the binary files read by read_and_decode.
        """
        self.file_list = file_list
        # Fixed record layout: label_bytes label byte(s) followed by a
        # height x width x channels uint8 image.
        self.height = 32
        self.width = 32
        self.channels = 3
        self.label_bytes = 1
        self.image_bytes = self.height * self.width * self.channels
        self.bytes = self.label_bytes + self.image_bytes

    def read_and_decode(self, batch_size=20, num_threads=1, capacity=None):
        """
        Build a pipeline that reads fixed-length records from the binary files.

        :param batch_size: number of examples per batch.
        :param num_threads: number of threads enqueueing examples.
        :param capacity: batch queue capacity; defaults to batch_size.
        :return: (label_batch, image_batch). label_batch is int32 with shape
                 [batch_size, label_bytes]; image_batch is uint8 with shape
                 [batch_size, height, width, channels].
        """
        file_queue = tf.train.string_input_producer(self.file_list)
        reader = tf.FixedLengthRecordReader(record_bytes=self.bytes)
        _, value = reader.read(file_queue)
        label_image = tf.decode_raw(value, tf.uint8)
        label = tf.cast(tf.slice(label_image, [0], [self.label_bytes]), tf.int32)
        image = tf.slice(label_image, [self.label_bytes], [self.image_bytes])
        # The CIFAR-10 binary format stores the pixels channel-major (all red
        # values, then all green, then all blue). Reshape to [C, H, W] first and
        # transpose to [H, W, C]; reshaping straight to [H, W, C] scrambles them.
        depth_major = tf.reshape(image, [self.channels, self.height, self.width])
        image_reshape = tf.transpose(depth_major, [1, 2, 0])
        label_batch, image_batch = tf.train.batch(
            [label, image_reshape],
            batch_size=batch_size,
            num_threads=num_threads,
            capacity=batch_size if capacity is None else capacity)
        return label_batch, image_batch

    def save_to_tfrecords(self, label_batch, image_batch,
                          path='../tmp/tfrecords/cifar.tfrecords'):
        """
        Evaluate ONE batch and write its examples to a TFRecords file.

        Must be called while a default session is registered (e.g. inside
        ``with tf.Session() as sess:``) whose queue runners are running.

        :param label_batch: label tensor of shape [batch_size, 1]
                            (as returned by read_and_decode).
        :param image_batch: uint8 image tensor of shape
                            [batch_size, height, width, channels].
        :param path: output TFRecords file path.
        :return: None
        :raises ValueError: if no default session is registered.
        """
        sess = tf.get_default_session()
        if sess is None:
            raise ValueError('save_to_tfrecords must be called with a default '
                             'session registered (e.g. inside `with tf.Session():`).')
        # Fetch both tensors in a single run. Every evaluation of a
        # tf.train.batch output dequeues a NEW batch, so evaluating labels and
        # images separately (or element by element) mismatches them.
        labels, images = sess.run([label_batch, image_batch])
        with tf.python_io.TFRecordWriter(path) as writer:
            for label, image in zip(labels, images):
                example = tf.train.Example(features=tf.train.Features(feature={
                    'image': tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[image.tobytes()])),
                    # int(): convert the numpy scalar to a plain Python int for protobuf.
                    'label': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=[int(label[0])])),
                }))
                writer.write(example.SerializeToString())

    def read_from_tfrecords(self, rec_file_list, batch_size=20, num_threads=1,
                            capacity=None):
        """
        Build a pipeline that reads examples written by save_to_tfrecords.

        NOTE: the return order is (image_batch, label_batch), which is the
        reverse of read_and_decode.

        :param rec_file_list: list of TFRecords file paths.
        :param batch_size: number of examples per batch.
        :param num_threads: number of threads enqueueing examples.
        :param capacity: batch queue capacity; defaults to batch_size.
        :return: (image_batch, label_batch). image_batch is uint8 with shape
                 [batch_size, height, width, channels]; label_batch is int32
                 with shape [batch_size].
        """
        file_queue = tf.train.string_input_producer(rec_file_list)
        reader = tf.TFRecordReader()
        _, value = reader.read(file_queue)
        example = tf.parse_single_example(value, features={
            'image': tf.FixedLenFeature(shape=[], dtype=tf.string),
            'label': tf.FixedLenFeature(shape=[], dtype=tf.int64),
        })
        # The image was stored as raw bytes and must be decoded back to uint8;
        # the int64 label comes back as a numeric tensor directly.
        image = tf.decode_raw(example['image'], tf.uint8)
        label = tf.cast(example['label'], tf.int32)
        image_reshape = tf.reshape(image, [self.height, self.width, self.channels])
        image_batch, label_batch = tf.train.batch(
            [image_reshape, label],
            batch_size=batch_size,
            num_threads=num_threads,
            capacity=batch_size if capacity is None else capacity)
        return image_batch, label_batch
if __name__ == '__main__':
    bin_dir = '../data/cifar-10-binary'
    rec_dir = '../tmp/tfrecords/'

    # Raw binary input files (names ending in 'bin').
    file_list = [os.path.join(bin_dir, file_name)
                 for file_name in os.listdir(bin_dir)
                 if file_name.endswith('bin')]
    tfrecords = tfrecords_file(file_list)

    # Pipeline over the binary files. It is only consumed when (re)creating the
    # TFRecords file: call tfrecords.save_to_tfrecords(label_batch, image_batch)
    # inside the session below, after the queue runners have been started.
    label_batch, image_batch = tfrecords.read_and_decode()

    rec_file_list = [os.path.join(rec_dir, file_name)
                     for file_name in os.listdir(rec_dir)]
    if not rec_file_list:
        raise SystemExit('No TFRecords files found in %s; write them first with '
                         'tfrecords_file.save_to_tfrecords().' % rec_dir)
    image, label = tfrecords.read_from_tfrecords(rec_file_list)

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)
        try:
            print(sess.run([image, label]))
        finally:
            # Always stop and join the queue-runner threads, even if run() fails.
            coord.request_stop()
            coord.join(threads)
注意
- 将数据写出为 TFRecords 格式时，需要为每个样本构造一个 Example 协议块（tf.train.Example 协议缓冲区）。
- Example 协议块类似于字典格式的数据。上述代码中，'image' 是键，对应的 tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])) 是值；解码时通过这个键取出对应的数据。
- 解码时，如果存储的特征是 string（字节串）类型，需要再用 tf.decode_raw 解码为数值数组；int64 和 float 类型的特征由 tf.parse_single_example 直接解析为数值，不需要再解码。