SSD算法中TFRecord的格式数据生成

目录

1、为什么使用TFRecord格式的文件

2、生成TFRecord文件的主要流程

3、SSD算法中生成TFRecord格式数据主要代码


1、为什么使用TFRecord格式的文件

在深度学习模型训练时,通常我们使用的数据量往往都会很大,这么多数据会占用很大的磁盘空间,并且在被一个个读取的时候会很慢、很繁琐,占用大量内存空间(有的大型数据不足以一次性加载)。TFRecord格式文件内部使用了“Protocol Buffer”二进制数据编码方案,它只需占用一个内存块,只需要一次性加载一个二进制文件。当数据比较大的时候,我们可以把数据制作成多个TFRecord文件,来提高处理效率。

2、生成TFRecord文件的主要流程

1)使用 tf.python_io.TFRecordWriter 函数生成一个TFRecord的生成器 

tf_filename = _get_output_filename(output_dir,name,fidx)
with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer

这里的output_dir参数是生成.tfrecord文件的路径

2)使用tf.train.Example函数生成example

example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': int64_feature(shape[0]),
        'image/width': int64_feature(shape[1]),
        'image/channels': int64_feature(shape[2]),
        'image/shape': int64_feature(shape),
        'image/object/bbox/xmin': float_feature(xmin),
        'image/object/bbox/xmax': float_feature(xmax),
        'image/object/bbox/ymin': float_feature(ymin),
        'image/object/bbox/ymax': float_feature(ymax),
        'image/filename': int64_feature(xml_pic_name),
        'image/object/bbox/label': int64_feature(labels),
        'image/object/bbox/label_text': bytes_feature(labels_text),
        'image/object/bbox/difficult': int64_feature(difficult),
        'image/object/bbox/truncated': int64_feature(truncated),
        'image/format': bytes_feature(image_format),
        'image/encoded': bytes_feature(image_data)}))

上面写入example中的数据格式主要有三种,分别是 BytesListFloatList,Int64List。其定义如下代码段:

def int64_feature(value):
    if not isinstance(value,list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def float_feature(value):
    if not isinstance(value,list):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def bytes_feature(value):
    if not isinstance(value,list):
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

 3)通过 tfrecord_writer.write(example.SerializeToString()) 语句将数据写入.tfrecord文件

 

3、SSD算法中生成TFRecord格式数据主要代码

import tensorflow as tf
import os
import sys
import random
import xml.etree.ElementTree as ET
from pascalvoc_common import VOC_LABELS
from dataset_utils import int64_feature,float_feature,bytes_feature
DIRECTORY_ANNOTATIONS = 'Annotations/'
DIRECTORY_IMAGES = '/JPEGImages/'

RANDOM_SEED = 4242
SAMPLES_PER_FILES = 1


def _get_output_filename(output_dir, name, idx):

    return '%s/%s_%03d.tfrecord ' % (output_dir,name,idx)


def _process_image(directory, name,i):
    """
    此程序是处理图片和xml文件

    :param dataset_dir: 数据集路径
    :param name: 要处理的图片名

    :return:

    """
    filename = directory + DIRECTORY_IMAGES + name + '.jpg' #图片的名字
    print(filename)

    image_data = tf.gfile.FastGFile(filename, 'rb').read() #对图片进行读取

    filename = os.path.join(directory, DIRECTORY_ANNOTATIONS, name + '.xml') #xml的名字
    tree = ET.parse(filename)  #解析xml文件
    root = tree.getroot()

    size = root.find('size')
    print(size)
    shape = [int(size.find('height').text),
             int(size.find('width').text),
             int(size.find('depth').text)] # 从xml中获取图片的宽、高、和深度 [512, 767, 3]
    print(shape)
    xml_pic_name = i+1

    bboxes = []
    labels = []
    labels_text = []  #[b'crack', b'crack', b'crack']
    difficult = []
    truncated = []
    all_object = root.findall('object')
    for obj in all_object:
        label = obj.find('name').text
        labels.append(int(VOC_LABELS[label][0])) #从这里得到crac标签前面的数字1,这里1就等于crack
        labels_text.append(label.encode('ascii')) #将label为crack变为[b'crack', b'crack', b'crack']

        #不知到下面这两个判断是什么意思
        if obj.find('difficult'):
            difficult.append(int(obj.find('difficult').text))
        else:
            difficult.append(0)
        if obj.find('truncated'):
            truncated.append(int(obj.find('truncated').text))
        else:
            truncated.append(0)

        bbox = obj.find('bndbox')
        ymin = float(bbox.find('ymin').text)/shape[0] #从xml中得到框的ymin再除以图片的高度
        xmin = float(bbox.find('xmin').text)/shape[1] #从xml中得到框的xmin再除以图片的高度
        ymax = float(bbox.find('ymax').text)/shape[0] #从xml中得到框的ymax再除以图片的高度
        xmax = float(bbox.find('xmax').text)/shape[1] #从xml中得到框的xmax再除以图片的高度

        bboxes.append((ymin,xmin,ymax,xmax))

    return image_data,shape,bboxes,labels,labels_text, difficult, truncated,xml_pic_name


def _convert_to_example(image_data, labels, labels_text, bboxes, shape,
                        difficult, truncated,xml_pic_name):
    """

    :param image_data: 读取的图片
    :param labels: 图片中标框的标签,这里使用数字对应表示的
    :param labels_text: 标签编码转换成文本 [b'crack', b'crack', b'crack']
    :param bboxes: 从xml中得到框的坐标后同图片的尺度缩放后的坐标
    :param shape: 图片的高、宽、深度
    :param difficult:
    :param truncated:
    :return:
    """
    #xml_pic_name = xml_pic_name.encode()
    xmin = [] #一张图片中所有框的xmin都放在这里
    ymin = []
    xmax = []
    ymax = []
    for i,b in enumerate(bboxes):
        assert len(b) == 4
        ymin.append(b[0])
        xmin.append(b[1])
        ymax.append(b[2])
        xmax.append(b[3])

        #[l.append(point) for l, point in zip([ymin, xmin, ymax, xmax], b)]
        #print("第%i个框" % i)
        # for l,point in zip([ymin, xmin, ymax, xmax],b):
        #    l.append(point)
    print(ymin)
    print(xmin)
    print(ymax)
    print(xmax)


    image_format = b'JPEG'
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': int64_feature(shape[0]),        #'image/filename'这条语句是我
        'image/width': int64_feature(shape[1]),         #自己后来加上的,目的是想提取xml
        'image/channels': int64_feature(shape[2]),      #文件里filename这个标签,由于名
        'image/shape': int64_feature(shape),            #字是字符串,我在取到后用
        'image/object/bbox/xmin': float_feature(xmin),  #xml_pic_name.encode()进行了编
        'image/object/bbox/xmax': float_feature(xmax),  #码,然后这里用的类型为
        'image/object/bbox/ymin': float_feature(ymin),  #bytes_feature但是在读取数据时
        'image/object/bbox/ymax': float_feature(ymax),  #却读不出来,一直提示数据类型错
        'image/filename': int64_feature(xml_pic_name),  # 误,一直没有解决这个问题           
        'image/object/bbox/label': int64_feature(labels),
        'image/object/bbox/label_text': bytes_feature(labels_text),
        'image/object/bbox/difficult': int64_feature(difficult),
        'image/object/bbox/truncated': int64_feature(truncated),
        'image/format': bytes_feature(image_format),
        'image/encoded': bytes_feature(image_data)}))


    return example

def _add_to_tfrecord(dataset_dir, name, tfrecord_writer,i):
    """
    从image和annotation文件中加载数据,并且将数据添加到TFRecord

    Args:
        :param dataset_dir: 数据集路径
        :param name: 添加到TFRecord中的图片的名字
        :param tfrecord_writer: 用于生成TFRecord文件的TFRecord生成器

        :return:

    """

    image_data, shape, bboxes, labels, labels_text, difficult, truncated,xml_pic_name = \
        _process_image(dataset_dir, name,i)
    example = _convert_to_example(image_data, labels, labels_text,
                                  bboxes, shape, difficult, truncated,xml_pic_name)
    tfrecord_writer.write(example.SerializeToString())

def run(dataset_dir, output_dir, name='voc_train', shuffling=False):
    """
    Run the conversion operation

    Args:
        :param dataset_dir:
        :param output_dir:
        :param name:
        :param shuffling:
        :return:
    """

    if not tf.gfile.Exists(dataset_dir):
        tf.gfile.MakeDirs(dataset_dir)

    path = os.path.join(dataset_dir, DIRECTORY_ANNOTATIONS)
    filenames = sorted(os.listdir(path))
    #print(filenames)
    if shuffling:
        random.seed(RANDOM_SEED)
        random.shuffle(filenames)

    i = 0
    fidx = 0
    while i < len(filenames):
        #打开新的tfrecord文件
        tf_filename = _get_output_filename(output_dir,name,fidx)
        with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
            j = 0
            while i < len(filenames) and j < SAMPLES_PER_FILES:  #SAMPLES_PER_FILES张图片转化成一个tfrecord文件里
                sys.stdout.write('\r>> 转化图片的进度 %d/%d' % (i+1,len(filenames)))
                sys.stdout.flush()

                filename = filenames[i]
                img_name = filename[:-4]
                _add_to_tfrecord(dataset_dir, img_name, tfrecord_writer,i)
                i += 1
                j += 1

            fidx += 1

    print("所有数据全部转化成了TFRecord的格式")

 xml文件

<?xml version="1.0" encoding="UTF-8"?>

-<annotation>

<folder>108MSDCF</folder>

<filename>DSC00001</filename>

<path>E:\tytd\Drone-pictures\Cracked-picture\108MSDCF\DSC00001.jpg</path>


-<source>

<database>Unknown</database>

</source>


-<size>

<width>767</width>

<height>512</height>

<depth>3</depth>

</size>

<segmented>0</segmented>


-<object>

<name>crack</name>

<pose>Unspecified</pose>

<truncated>0</truncated>

<difficult>0</difficult>


-<bndbox>

<xmin>485</xmin>

<ymin>29</ymin>

<xmax>523</xmax>

<ymax>45</ymax>

</bndbox>

</object>


-<object>

<name>crack</name>

<pose>Unspecified</pose>

<truncated>0</truncated>

<difficult>0</difficult>


-<bndbox>

<xmin>529</xmin>

<ymin>169</ymin>

<xmax>577</xmax>

<ymax>186</ymax>

</bndbox>

</object>


-<object>

<name>crack</name>

<pose>Unspecified</pose>

<truncated>0</truncated>

<difficult>0</difficult>


-<bndbox>

<xmin>494</xmin>

<ymin>207</ymin>

<xmax>573</xmax>

<ymax>231</ymax>

</bndbox>

</object>

</annotation>

 

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值