SSD-Tensorflow项目源码学习:将数据集转化为为TFR文件

TF-Record介绍

tfrecord数据文件是一种将图像数据和标签统一存储的二进制文件,能更好的利用内存,在tensorflow中快速的复制,移动,读取,存储等。
tfrecord文件包含了tf.train.Example 协议缓冲区(protocol buffer,协议缓冲区包含了特征 Features)。你可以写一段代码获取你的数据, 将数据填入到Example协议缓冲区(protocol buffer),将协议缓冲区序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter class写入到TFRecords文件。
TFRecords文件格式在图像识别中有很好的使用,其可以将二进制数据和标签数据(训练的类别标签)数据存储在同一个文件中,它可以在模型进行训练之前通过预处理步骤将图像转换为TFRecords格式,此格式最大的优点实践每幅输入图像和与之关联的标签放在同一个文件中,其不对数据进行压缩,所以可以被快速加载到内存中.格式不支持随机访问,因此它适合于大量的数据流,但不适用于快速分片或其他非连续存取。
以下是一个示例:

message Example {
 Features features = 1;
};

message Features{
 map<string,Feature> featrue = 1;
};

message Feature{
    oneof kind{
        BytesList bytes_list = 1;
        FloatList float_list = 2;
        Int64List int64_list = 3;
    }

TF-Record的具体说明参考:https://blog.csdn.net/u012759136/article/details/52232266

数据格式介绍

VOC2012文件夹组成
如图,JPEGImages文件夹存放图片,Annotations文件夹存放标签、bbox等信息。
JPEGImages
Annotations
xml文件的具体内容示例:

<annotation>
    <folder>VOC2012</folder>
    <filename>2007_000129.jpg</filename>
    <source>
        <database>The VOC2007 Database</database>
        <annotation>PASCAL VOC2007</annotation>
        <image>flickr</image>
    </source>
    <size>
        <width>334</width>
        <height>500</height>
        <depth>3</depth>
    </size>
    <segmented>1</segmented>
    <object>
        <name>bicycle</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>70</xmin>
            <ymin>202</ymin>
            <xmax>255</xmax>
            <ymax>500</ymax>
        </bndbox>
    </object>
    <object>
        <name>bicycle</name>
        <pose>Unspecified</pose>
        <truncated>1</truncated>
        <difficult>1</difficult>
        <bndbox>
            <xmin>251</xmin>
            <ymin>242</ymin>
            <xmax>334</xmax>
            <ymax>500</ymax>
        </bndbox>
    </object>
    <object>
        <name>bicycle</name>
        <pose>Unspecified</pose>
        <truncated>1</truncated>
        <difficult>1</difficult>
        <bndbox>
            <xmin>1</xmin>
            <ymin>144</ymin>
            <xmax>67</xmax>
            <ymax>436</ymax>
        </bndbox>
    </object>
    <object>
        <name>person</name>
        <pose>Unspecified</pose>
        <truncated>1</truncated>
        <difficult>1</difficult>
        <bndbox>
            <xmin>1</xmin>
            <ymin>1</ymin>
            <xmax>66</xmax>
            <ymax>363</ymax>
        </bndbox>
    </object>
    <object>
        <name>person</name>
        <pose>Frontal</pose>
        <truncated>1</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>74</xmin>
            <ymin>1</ymin>
            <xmax>272</xmax>
            <ymax>462</ymax>
        </bndbox>
    </object>
    <object>
        <name>person</name>
        <pose>Unspecified</pose>
        <truncated>1</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>252</xmin>
            <ymin>19</ymin>
            <xmax>334</xmax>
            <ymax>487</ymax>
        </bndbox>
    </object>
</annotation>

Python中解析XML文件

Python中使用模块xml.etree.ElementTree来解析XML文件

TFR数据生成

根据Github上的说明,当前SSD-TensorFlow项目仅支持Pascal VOC dataset(2007 or 2012)。为了能在SSD模型训练中使用,数据集需要被转化为TF-Record,通过tf_convert_data.py部分代码。

DATASET_DIR=./VOC2007/test/
OUTPUT_DIR=./tfrecords
python tf_convert_data.py \
    --dataset_name=pascalvoc \
    --dataset_dir=${DATASET_DIR} \
    --output_name=voc_2007_train \
    --output_dir=${OUTPUT_DIR}

tf_convert_data.py

def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError('You must supply the dataset directory with --dataset_dir')
    print('Dataset directory:', FLAGS.dataset_dir)
    print('Output directory:', FLAGS.output_dir)

    if FLAGS.dataset_name == 'pascalvoc':
        pascalvoc_to_tfrecords.run(FLAGS.dataset_dir, FLAGS.output_dir, FLAGS.output_name)
    else:
        raise ValueError('Dataset [%s] was not recognized.' % FLAGS.dataset_name)

该段代码调用了pascalvoc_to_tfrecords.run

pascalvoc_to_tfrecords.py

def run(dataset_dir, output_dir, name='voc_train', shuffling=False):
    """Runs the conversion operation.

    Args:
      dataset_dir: The dataset directory where the dataset is stored.
      output_dir: Output directory.
    """
    if not tf.gfile.Exists(dataset_dir):
        tf.gfile.MakeDirs(dataset_dir)

    # Dataset filenames, and shuffling.
    path = os.path.join(dataset_dir, DIRECTORY_ANNOTATIONS)  # 标签存放的路径
    filenames = sorted(os.listdir(path))   # 排序
    if shuffling:
        random.seed(RANDOM_SEED)
        random.shuffle(filenames) # shuffle 将序列内元素打乱

    # Process dataset files.
    i = 0
    fidx = 0
    while i < len(filenames):
        # Open new TFRecord file.
        tf_filename = _get_output_filename(output_dir, name, fidx)     # 获取输出文件名
        with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
            j = 0
            while i < len(filenames) and j < SAMPLES_PER_FILES:  # 一个文件200张图
                sys.stdout.write('\r>> Converting image %d/%d' % (i+1, len(filenames)))  # 用于多线程输出
                sys.stdout.flush()  # 强制刷新缓冲区 这两行不会生成多行报告 而是在一行不断刷新

                filename = filenames[i]
                img_name = filename[:-4]  # path路径下的文件名是'img_name.xml'
                _add_to_tfrecord(dataset_dir, img_name, tfrecord_writer)  
                i += 1
                j += 1
            fidx += 1

    # Finally, write the labels file:
    # labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
    # dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
    print('\nFinished converting the Pascal VOC dataset!')

该段代码调用了_get_output_filename_add_to_tfrecord

def _add_to_tfrecord(dataset_dir, name, tfrecord_writer):
    """Loads data from image and annotations files and add them to a TFRecord.

    Args:
      dataset_dir: Dataset directory;
      name: Image name to add to the TFRecord;
      tfrecord_writer: The TFRecord writer to use for writing.
    """
    image_data, shape, bboxes, labels, labels_text, difficult, truncated = \
        _process_image(dataset_dir, name)                                # 读取数据集
    example = _convert_to_example(image_data, labels, labels_text,      
                                  bboxes, shape, difficult, truncated)   # 写入数据集
    tfrecord_writer.write(example.SerializeToString())


def _get_output_filename(output_dir, name, idx):
    return '%s/%s_%03d.tfrecord' % (output_dir, name, idx)

_add_to_tfrecord 包含了读和写两步,先看_process_image 代码

def _process_image(directory, name):
    """Process a image and annotation file.

    Args:
      filename: string, path to an image file e.g., '/path/to/example.JPG'.
      coder: instance of ImageCoder to provide TensorFlow image coding utils.
    Returns:
      image_buffer: string, JPEG encoding of RGB image.
      height: integer, image height in pixels.
      width: integer, image width in pixels.
    """
    # Read the image file.
    filename = directory + DIRECTORY_IMAGES + name + '.jpg'
    image_data = tf.gfile.FastGFile(filename, 'r').read()     # 应该是rb

    # Read the XML annotation file.
    filename = os.path.join(directory, DIRECTORY_ANNOTATIONS, name + '.xml')
    tree = ET.parse(filename)  # 返回解析树
    root = tree.getroot()      # 获取根节点

    # Image shape.
    size = root.find('size')   # 查找第一个匹配的子元素
    shape = [int(size.find('height').text),
             int(size.find('width').text),
             int(size.find('depth').text)]
    # Find annotations.
    bboxes = []
    labels = []
    labels_text = []
    difficult = []
    truncated = []
    for obj in root.findall('object'):   # 返回所有匹配的子元素列表
        label = obj.find('name').text
        labels.append(int(VOC_LABELS[label][0]))     # VOC_LABELS 是一个字典 详情去看Pascalvpc_common.py
        labels_text.append(label.encode('ascii'))

        if obj.find('difficult'):
            difficult.append(int(obj.find('difficult').text))
        else:
            difficult.append(0)
        if obj.find('truncated'):
            truncated.append(int(obj.find('truncated').text))
        else:
            truncated.append(0)

        bbox = obj.find('bndbox')
        bboxes.append((float(bbox.find('ymin').text) / shape[0],
                       float(bbox.find('xmin').text) / shape[1],
                       float(bbox.find('ymax').text) / shape[0],
                       float(bbox.find('xmax').text) / shape[1]
                       ))
    return image_data, shape, bboxes, labels, labels_text, difficult, truncated
def _convert_to_example(image_data, labels, labels_text, bboxes, shape,
                        difficult, truncated):
    """Build an Example proto for an image example.

    Args:
      image_data: string, JPEG encoding of RGB image;
      labels: list of integers, identifier for the ground truth;
      labels_text: list of strings, human-readable labels;
      bboxes: list of bounding boxes; each box is a list of integers;
          specifying [xmin, ymin, xmax, ymax]. All boxes are assumed to belong
          to the same label as the image label.
      shape: 3 integers, image shapes in pixels.
    Returns:
      Example proto
    """
    xmin = []
    ymin = []
    xmax = []
    ymax = []
    for b in bboxes:
        assert len(b) == 4
        # pylint: disable=expression-not-assigned
        [l.append(point) for l, point in zip([ymin, xmin, ymax, xmax], b)]
        # pylint: enable=expression-not-assigned

    image_format = b'JPEG'
    example = tf.train.Example(features=tf.train.Features(feature={
            'image/height': int64_feature(shape[0]),
            'image/width': int64_feature(shape[1]),
            'image/channels': int64_feature(shape[2]),
            'image/shape': int64_feature(shape),
            'image/object/bbox/xmin': float_feature(xmin),
            'image/object/bbox/xmax': float_feature(xmax),
            'image/object/bbox/ymin': float_feature(ymin),
            'image/object/bbox/ymax': float_feature(ymax),
            'image/object/bbox/label': int64_feature(labels),
            'image/object/bbox/label_text': bytes_feature(labels_text),
            'image/object/bbox/difficult': int64_feature(difficult),
            'image/object/bbox/truncated': int64_feature(truncated),
            'image/format': bytes_feature(image_format),
            'image/encoded': bytes_feature(image_data)}))
    return example
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值