Pascal VOC数据集转成tfrecord文件形式---------------------解读!!!

原因:我自己想对数据集进行一些预处理,看大神的代码,终于理解是怎么转成tensorflow的数据文件tfrecord!!!!

代码链接:https://github.com/HiKapok/SSD.TensorFlow

什么是TFRecord?

       TRecord数据文件是一种将图像数据和标签统一存储的二进制文件,能更好的利用内存,在TensorFlow中快速的复制,移动,读取,存储等。对于我们普通开发者而言,我们并不需要关心这些,Tensorflow 提供了丰富的 API 可以帮助我们轻松读写 TFRecord文件。我们只关心如何使用Tensorflow生成TFRecord,并且读取它。

1.将数据保存为tfrecords

       TFRecords文件包含了tf.train.Example协议内存块(protocol buffer)(协议内存块包含了字段 Features)。我们可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter写入到TFRecords文件。
流程:

1. 将数据填入example protocol buffer 
2. 将protocol buffer序列化为一个字符串 
3. 通过tf.python_io.TFRecordWriter将字符串写入TFRecords文件

我们首先看convert_tfrecords.py

1.给的参数:

你的数据集存放的方式:

'''How to organize your dataset folder:
  VOCROOT/
       |->VOC2007/
       |    |->Annotations/
       |    |->ImageSets/
       |    |->...
       |->VOC2012/
       |    |->Annotations/
       |    |->ImageSets/
       |    |->...
       |->VOC2007TEST/
       |    |->Annotations/
       |    |->...
'''
tf.app.flags.DEFINE_string('dataset_directory', 'D:/Program Files(x86)/PycharmProjects/VOCROOT/',
                           'All datas directory')  #所有文件的路径
tf.app.flags.DEFINE_string('train_splits', 'VOC2007, VOC2012',
                           'Comma-separated list of the training data sub-directory') #训练数据子文件夹分别的目录名字
tf.app.flags.DEFINE_string('validation_splits', 'VOC2007TEST',
                           'Comma-separated list of the validation data sub-directory')  #测试数据集子目录的名字
tf.app.flags.DEFINE_string('output_directory', 'D:/Program Files(x86)/PycharmProjects/convert-tf/dataset/tfrecords/',
                           'Output data directory')     #输出的tfrecord所在文件夹位置
tf.app.flags.DEFINE_integer('train_shards', 16,
                            'Number of shards in training TFRecord files.')  #训练tfrecord文件夹中的碎片数量16
tf.app.flags.DEFINE_integer('validation_shards', 16,
                            'Number of shards in validation TFRecord files.') #测试tfrecord文件夹中的碎片数量16
tf.app.flags.DEFINE_integer('num_threads', 8,
                            'Number of threads to preprocess the images.')  #预处理图像的线程数8

 2.将数据集转化成对应的类型:

def _int64_feature(value):
  """Wrapper for inserting int64 features into Example proto."""
  if not isinstance(value, list):   #数据类型检查
    value = [value]
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def _float_feature(value):
  """Wrapper for inserting float features into Example proto."""
  if not isinstance(value, list):
    value = [value]
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def _bytes_list_feature(value):
    """Wrapper for inserting a list of bytes features into Example proto.
    """
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

def _bytes_feature(value):
  """Wrapper for inserting bytes features into Example proto."""
  if isinstance(value, six.string_types):
    value = six.binary_type(value, encoding='utf-8')
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

 3.我们从主函数看起:一步一步调用到另一个函数。

def main(unused_argv):
    #断言,%取余
  assert not FLAGS.train_shards % FLAGS.num_threads, (
      'Please make the FLAGS.num_threads commensurate with FLAGS.train_shards')
  assert not FLAGS.validation_shards % FLAGS.num_threads, (
      'Please make the FLAGS.num_threads commensurate with '
      'FLAGS.validation_shards')
  print('Saving results to %s' % FLAGS.output_directory)

  # Run it!
  _process_dataset('val', FLAGS.dataset_directory, parse_comma_list(FLAGS.validation_splits), FLAGS.validation_shards)
  _process_dataset('train', FLAGS.dataset_directory, parse_comma_list(FLAGS.train_splits), FLAGS.train_shards)

 这里首先要保证train_shards/num_threads是整数,

train_shards:将训练数据集分成几块,这里我们上面给的参数是16,就是分成16块,就是最后的训练数据就是16个tfrecord格式的文件。

num_threads:是线程的数量,这里给的参数是8,采用8个线程产生数据。注意:线程数必须要被能整出train_shards和validation_shards,来保证每个线程处理的数据块数是相同的。

也就是16/8=2,每个线程处理2个tfrecord.。首先输出图片保存的路径,然后调用第一个函数_process_dataset(),我们以train举例

3.1

_process_dataset('train', FLAGS.dataset_directory, parse_comma_list(FLAGS.train_splits), FLAGS.train_shards)

parse_comma_list()是作者自己写的函数:

  • 0
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值