Introduction to TFRecord
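TensorFlow provides a unified format for storing data: TFRecord. A TFRecord file stores data in binary form and is well suited to reading large amounts of data sequentially.
Creating a tfrecords file is straightforward. Assemble each sample into an Example object, which follows the protocol buffer format; serialize the Example to a string; then write it to the tfrecords file with tf.python_io.TFRecordWriter (the TensorFlow 1.x name; in TensorFlow 2.x the same class is available as tf.io.TFRecordWriter). The general steps are: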
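Step 1: Obtain the raw data, typically preprocessing it with numpy or pandas.
Step 2: Create a tfrecords file with the tf.python_io.TFRecordWriter class.
Step 3: Organize each sample by its features, that is, build an Example from the sample. This is the core of the whole workflow and the most involved part.
Step 4: Write each assembled Example to the tfrecords file, then close the file.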
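The four steps can be sketched end to end as follows. This is a minimal illustration rather than the project's code: it assumes TensorFlow 2.x (hence tf.io.TFRecordWriter and eager iteration over tf.data.TFRecordDataset), and the random data, the feature names "x" and "y", and the output path are all made up for the example.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Step 1: prepare some raw samples (random data stands in for a real dataset).
rng = np.random.default_rng(0)
features = rng.random((5, 3)).astype(np.float32)  # 5 samples, 3 floats each
labels = np.arange(5, dtype=np.int64)              # one integer label per sample

# Step 2: open a tfrecords file for writing (hypothetical temporary path).
record_path = os.path.join(tempfile.mkdtemp(), "demo.tfrecords")
writer = tf.io.TFRecordWriter(record_path)

for x, y in zip(features, labels):
    # Step 3: organize one sample into an Example protocol buffer.
    example = tf.train.Example(features=tf.train.Features(feature={
        "x": tf.train.Feature(float_list=tf.train.FloatList(value=x.tolist())),
        "y": tf.train.Feature(int64_list=tf.train.Int64List(value=[int(y)])),
    }))
    # Step 4: serialize the Example and append it to the file.
    writer.write(example.SerializeToString())

# Closing the writer flushes everything to disk.
writer.close()

# Read the file back to check that one record was stored per sample.
num_records = sum(1 for _ in tf.data.TFRecordDataset(record_path))
```

Using the writer in a with statement, as the PNet code does, closes the file automatically and is usually the safer choice.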
**Generating the PNet TFRecord file**
Running gen_PNet_tfrecords.py generates the TFRecord file for PNet. The core logic is as follows:
import random
import sys

import tensorflow as tf

# _get_output_filename, get_dataset and _add_to_tfrecord are project helpers
# (the first two are not shown here).


def run(dataset_dir, net, output_dir, name='MTCNN', shuffling=False):
    """Runs the conversion operation.

    Args:
      dataset_dir: The dataset directory where the dataset is stored.
      net: Network name, passed on to get_dataset and _get_output_filename.
      output_dir: Output directory.
      name: Dataset name, used to build the output filename.
      shuffling: Whether to shuffle the dataset before writing.
    """
    # Name of the tfrecords file.
    tf_filename = _get_output_filename(output_dir, name, net)
    if tf.gfile.Exists(tf_filename):
        print('Dataset files already exist. Exiting without re-creating them.')
        return
    # Get the dataset and optionally shuffle it. The image path, label and
    # landmark of each sample are read from the previously merged
    # train_PNet_landmark.txt file and stored in the dataset list.
    dataset = get_dataset(dataset_dir, net=net)
    if shuffling:
        tf_filename = tf_filename + '_shuffle'
        # random.seed(12345454)
        random.shuffle(dataset)
    # Process the dataset and write it to the tfrecords file.
    # Create the TFRecord writer.
    with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
        for i, image_example in enumerate(dataset):
            # Report progress every 100 images.
            if (i + 1) % 100 == 0:
                sys.stdout.write('\r>> %d/%d images have been converted' % (i + 1, len(dataset)))
                sys.stdout.flush()
            filename = image_example['filename']
            _add_to_tfrecord(filename, image_example, tfrecord_writer)
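The get_dataset helper is not shown in this article. The sketch below only illustrates the dict layout that the conversion code expects from each dataset entry ('filename', 'label', and a 'bbox' dict with the ROI and landmark keys). The annotation line format used here (image path, label, then either 4 ROI values or 10 landmark values, separated by whitespace) is an assumption, and parse_annotation_line is a hypothetical name.

```python
BBOX_KEYS = ['xmin', 'ymin', 'xmax', 'ymax']
LANDMARK_KEYS = ['xlefteye', 'ylefteye', 'xrighteye', 'yrighteye',
                 'xnose', 'ynose', 'xleftmouth', 'yleftmouth',
                 'xrightmouth', 'yrightmouth']


def parse_annotation_line(line):
    """Turn one annotation line into the dict layout used by the conversion code.

    Assumed format (not confirmed by the article):
        <image path> <label> [4 ROI values | 10 landmark values]
    """
    fields = line.strip().split()
    # Every coordinate defaults to 0.0 so that all examples share the same keys.
    bbox = {key: 0.0 for key in BBOX_KEYS + LANDMARK_KEYS}
    values = [float(v) for v in fields[2:]]
    if len(values) == len(BBOX_KEYS):
        bbox.update(zip(BBOX_KEYS, values))
    elif len(values) == len(LANDMARK_KEYS):
        bbox.update(zip(LANDMARK_KEYS, values))
    return {'filename': fields[0], 'label': int(fields[1]), 'bbox': bbox}


# Example with a made-up path and ROI values.
sample = parse_annotation_line('12/positive/0.jpg 1 0.05 -0.10 0.20 0.15')
```

Filling missing coordinates with zeros keeps every Example the same shape, which makes fixed-length parsing possible when the records are read back.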
The run function calls _add_to_tfrecord, which is defined as follows:
def _add_to_tfrecord(filename, image_example, tfrecord_writer):
    """Loads the data of one image and adds it to a TFRecord file.

    Args:
      filename: Path of the image file.
      image_example: Dict holding the image's info (label and bbox).
      tfrecord_writer: The TFRecord writer to use for writing.
    """
    # image_data: the image array converted to a string
    # height: the original image's height
    # width: the original image's width
    image_data, height, width = _process_image_withoutcoder(filename)
    # Call _convert_to_example_simple to organize the sample by its features,
    # that is, to build an Example from the sample data.
    example = _convert_to_example_simple(image_example, image_data)
    tfrecord_writer.write(example.SerializeToString())
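The _process_image_withoutcoder helper is not shown here; according to the comments it converts the image array to a string and returns it with the image's height and width. The sketch below illustrates that idea with NumPy only. It assumes the image is already a uint8 array with 3 channels, and image_array_to_bytes and bytes_to_image_array are hypothetical names, not the project's functions.

```python
import numpy as np


def image_array_to_bytes(image):
    """Return the raw bytes of an image array together with its height and width."""
    height, width = image.shape[:2]
    return image.tobytes(), height, width


def bytes_to_image_array(image_data, height, width, channels=3):
    """Rebuild the uint8 image array from its raw bytes."""
    return np.frombuffer(image_data, dtype=np.uint8).reshape(height, width, channels)


# A small dummy image (12 x 12 pixels, 3 channels) stands in for a real file.
image = (np.arange(12 * 12 * 3) % 256).astype(np.uint8).reshape(12, 12, 3)
image_data, height, width = image_array_to_bytes(image)
restored = bytes_to_image_array(image_data, height, width)
```

Raw bytes carry no shape information, so whoever reads the record has to know the image size in advance or store it alongside the bytes.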
The _convert_to_example_simple function is as follows:
def _convert_to_example_simple(image_example, image_buffer):
    """Converts one sample to an Example proto for the TFRecord file.

    :param image_example: dict, an image example
    :param image_buffer: string, the image data returned by
        _process_image_withoutcoder
    :return: Example proto
    """
    # Class label for the whole image.
    class_label = image_example['label']
    bbox = image_example['bbox']
    # Bounding box: xmin, ymin, xmax, ymax.
    roi = [bbox['xmin'], bbox['ymin'], bbox['xmax'], bbox['ymax']]
    # Five landmark points stored as (x, y) pairs.
    landmark = [bbox['xlefteye'], bbox['ylefteye'],
                bbox['xrighteye'], bbox['yrighteye'],
                bbox['xnose'], bbox['ynose'],
                bbox['xleftmouth'], bbox['yleftmouth'],
                bbox['xrightmouth'], bbox['yrightmouth']]

    example = tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': _bytes_feature(image_buffer),
        'image/label': _int64_feature(class_label),
        'image/roi': _float_feature(roi),
        'image/landmark': _float_feature(landmark)
    }))
    return example
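The _bytes_feature, _int64_feature and _float_feature helpers used above are not shown in this article. A common way to write such wrappers is sketched below; the bodies are an assumption about what the project's helpers do, and the feature values are made up. The round trip at the end only checks that the serialized Example can be parsed back.

```python
import tensorflow as tf


def _as_list(value):
    """Wrap a single value in a list; the protobuf list types expect a sequence."""
    return list(value) if isinstance(value, (list, tuple)) else [value]


def _int64_feature(value):
    """Wrap an int (or a list of ints) in an Int64List feature."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=_as_list(value)))


def _float_feature(value):
    """Wrap a float (or a list of floats) in a FloatList feature."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=_as_list(value)))


def _bytes_feature(value):
    """Wrap a bytes object (or a list of them) in a BytesList feature."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=_as_list(value)))


# Build an Example with the same keys as _convert_to_example_simple (dummy values).
example = tf.train.Example(features=tf.train.Features(feature={
    'image/encoded': _bytes_feature(b'\x00\x01\x02'),
    'image/label': _int64_feature(1),
    'image/roi': _float_feature([0.1, 0.2, 0.3, 0.4]),
    'image/landmark': _float_feature([0.0] * 10),
}))
serialized = example.SerializeToString()

# Parse the bytes back into an Example to inspect what was stored.
parsed = tf.train.Example.FromString(serialized)
stored = parsed.features.feature
```

Note that FloatList stores 32-bit floats, so values such as 0.1 come back with float32 precision rather than as the exact Python float.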