原因:我自己想对数据集进行一些预处理,看大神的代码,终于理解是怎么转成tensorflow的数据文件tfrecord!!!!
代码链接:https://github.com/HiKapok/SSD.TensorFlow
什么是TFRecord?
TRecord数据文件是一种将图像数据和标签统一存储的二进制文件,能更好的利用内存,在TensorFlow中快速的复制,移动,读取,存储等。对于我们普通开发者而言,我们并不需要关心这些,Tensorflow 提供了丰富的 API 可以帮助我们轻松读写 TFRecord文件。我们只关心如何使用Tensorflow生成TFRecord,并且读取它。
1.将数据保存为tfrecords
TFRecords文件包含了tf.train.Example协议内存块(protocol buffer)(协议内存块包含了字段 Features)。我们可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter写入到TFRecords文件。
流程:
1. 将数据填入example protocol buffer
2. 将protocol buffer
序列化为一个字符串
3. 通过tf.python_io.TFRecordWriter
将字符串写入TFRecords
文件
我们首先看convert_tfrecords.py
1.给的参数:
你的数据集存放的方式:
'''How to organize your dataset folder: VOCROOT/ |->VOC2007/ | |->Annotations/ | |->ImageSets/ | |->... |->VOC2012/ | |->Annotations/ | |->ImageSets/ | |->... |->VOC2007TEST/ | |->Annotations/ | |->... '''
tf.app.flags.DEFINE_string('dataset_directory', 'D:/Program Files(x86)/PycharmProjects/VOCROOT/',
'All datas directory') #所有文件的路径
tf.app.flags.DEFINE_string('train_splits', 'VOC2007, VOC2012',
'Comma-separated list of the training data sub-directory') #训练数据子文件夹分别的目录名字
tf.app.flags.DEFINE_string('validation_splits', 'VOC2007TEST',
'Comma-separated list of the validation data sub-directory') #测试数据集子目录的名字
tf.app.flags.DEFINE_string('output_directory', 'D:/Program Files(x86)/PycharmProjects/convert-tf/dataset/tfrecords/',
'Output data directory') #输出的tfrecord所在文件夹位置
tf.app.flags.DEFINE_integer('train_shards', 16,
'Number of shards in training TFRecord files.') #训练tfrecord文件夹中的碎片数量16
tf.app.flags.DEFINE_integer('validation_shards', 16,
'Number of shards in validation TFRecord files.') #测试tfrecord文件夹中的碎片数量16
tf.app.flags.DEFINE_integer('num_threads', 8,
'Number of threads to preprocess the images.') #预处理图像的线程数8
2.将数据集转化成对应的类型:
def _int64_feature(value):
"""Wrapper for inserting int64 features into Example proto."""
if not isinstance(value, list): #数据类型检查
value = [value]
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def _float_feature(value):
"""Wrapper for inserting float features into Example proto."""
if not isinstance(value, list):
value = [value]
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def _bytes_list_feature(value):
"""Wrapper for inserting a list of bytes features into Example proto.
"""
if not isinstance(value, list):
value = [value]
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
def _bytes_feature(value):
"""Wrapper for inserting bytes features into Example proto."""
if isinstance(value, six.string_types):
value = six.binary_type(value, encoding='utf-8')
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
3.我们从主函数看起:一步一步调用到另一个函数。
def main(unused_argv):
#断言,%取余
assert not FLAGS.train_shards % FLAGS.num_threads, (
'Please make the FLAGS.num_threads commensurate with FLAGS.train_shards')
assert not FLAGS.validation_shards % FLAGS.num_threads, (
'Please make the FLAGS.num_threads commensurate with '
'FLAGS.validation_shards')
print('Saving results to %s' % FLAGS.output_directory)
# Run it!
_process_dataset('val', FLAGS.dataset_directory, parse_comma_list(FLAGS.validation_splits), FLAGS.validation_shards)
_process_dataset('train', FLAGS.dataset_directory, parse_comma_list(FLAGS.train_splits), FLAGS.train_shards)
这里首先要保证train_shards/num_threads是整数,
train_shards:将训练数据集分成几块,这里我们上面给的参数是16,就是分成16块,就是最后的训练数据就是16个tfrecord格式的文件。
num_threads:是线程的数量,这里给的参数是8,采用8个线程产生数据。注意:线程数必须要被能整出train_shards和validation_shards,来保证每个线程处理的数据块数是相同的。
也就是16/8=2,每个线程处理2个tfrecord.。首先输出图片保存的路径,然后调用第一个函数_process_dataset(),我们以train举例
3.1
_process_dataset('train', FLAGS.dataset_directory, parse_comma_list(FLAGS.train_splits), FLAGS.train_shards)
parse_comma_list()是作者自己写的函数: