将CompCars数据集转换成.tfrecord格式

前言:

无论object detection 或者object classification都需要对原始图像进行相应的处理,使其生成适用于训练的数据.
object classification ---需要图像,以及该图像对应的label.
object detection ---需要图像,以及该图像对应的label和(xmin,ymin)(xmax,ymax)
所以通常数据级会有俩部分构成
annotation -存储的图像的信息,即文件名字和对应该文件名的 label和(xmin,ymin)(xmax,ymax)通常是xml的形式
image ---图像文件

常见的数据集:

The Comprehensive Cars (CompCars) dataset:汽车相关的数据集,以这个为例讲解如何生成.tfrecord格式的文件

吐槽:

----一些不成熟的小建议
<1>The Cars dataset contains 16,185 images of 196 classes of cars.很感谢CompCars提供了这些数据,但它提供的类别文件只是数字,并没有描述清楚每个数字对应的类别
<2>提供的annotation是.mat格式,解析这种格式的文件感觉不如xml格式的文件方便.
<3> 这是在下载完数据之后它的readme.md描述,这里有一个错误
readme.md描述
bbox_x1: Min x-value of the bounding box, in pixels
bbox_x2: Max x-value of the bounding box, in pixels
bbox_y1: Min y-value of the bounding box, in pixels
bbox_y2: Max y-value of the bounding box, in pixels
更正如下
bbox_x1: Min x-value of the bounding box, in pixels
bbox_x2:Min y-value of the bounding box, in pixels
bbox_y1: Max x-value of the bounding box, in pixels
bbox_y2: Max y-value of the bounding box, in pixels
*****特别注意这一点
**

数据解析:

**
下载文件
文件结构如下:

关键过程如下:

import scipy.io #导入读取文件的库
data = scipy.io.loadmat(PATH_TO_YOUR_annotations_file)#读取相应的文件
annotations = data['annotations']
#文件列表和annotations具有对应关系  
#filename将数字前面的0去掉然后在减1就是annotations的index
example_index = example.lstrip('0')#example文件名  去0
example_index = int(example_index) - 1  #-1
annotations.item(example_index)


----------
xmin = annotations[0][0][0]
ymin = annotations[1][0][0]
xmax = annotations[2][0][0]
ymax = annotations[3][0][0]
#归一化处理
xmins.append(xmin / image_width)   
ymins.append(ymin / image_height)
xmaxs.append(xmax / image_width)
ymaxs.append(ymax / image_height)

#放到feature_dict
feature_dict = {
        'image/height': dataset_util.int64_feature(image_height),
        'image/width': dataset_util.int64_feature(image_width),
        'image/filename': dataset_util.bytes_feature(
            image_filename.encode('utf8')),
        'image/source_id': dataset_util.bytes_feature(
            image_filename.encode('utf8')),
        'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/label': dataset_util.int64_list_feature(image_class),
    }

有些时候需要训练集和验证集在一个文件下,这个时候需要将该文件夹下所有文件遍历,方便训练集和验证集划分.
获取文件夹下所有文件名称,代码如下:

import os

source_folder='PATH_TO_YOUR_RESTORE/cars_train'#存储图片的目录
dest='./train.txt'#存储train.txt目录
file_list=os.listdir(source_folder)#./image/图片所在路径的文件夹列表
train_file=open(dest,'a')#打开该文件
for file_obj in file_list: #访问文件列表中所有的文件
    file_path=os.path.join(source_folder,file_obj)
    file_name,file_extend=os.path.splitext(file_obj)
    #file_name 保存文件的名字,file_extend保存文件扩展名
    train_file.write(file_name+'\n')#将文件名称写入train_file中并换行
train_file.close()#关闭文件

关于create_vehicle_tf_record.py完整代码如下

import hashlib
import io
import logging
import os
import random
import re
import argparse
from lxml import etree
import numpy as np
import PIL.Image
import tensorflow as tf
import scipy.io

from object_detection.utils import dataset_util
from object_detection.utils import label_map_util

def parse_args(check=True):
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_data_dir', type=str, default='PATH_TO_YOUR_RESTORE/dataBase/cars_train',
                        help='Root directory to import ospet dataset.')

    parser.add_argument('--train_filename_list', type=str, default='PATH_TO_YOUR_RESTORE/dataBase/devkit/train.txt',
                        help='Root directory to import ospet dataset.')

    parser.add_argument('--annotations_file', type=str, default='./cars_train_annos.mat',
                        help='Root directory to import ospet dataset.')

    parser.add_argument('--output_dir', type=str, default='./output',
                        help='Path to directory to output TFRecords.')
    parser.add_argument('--label_map_path', type=str, default='PATH_TO_YOUR_RESTORE/dataBase/quiz-w8-data/labels_items.txt',
                        help='Path to label map proto.')

    FLAGS, unparsed = parser.parse_known_args()

    return FLAGS, unparsed

def dict_to_tf_example(example , image_dir,annotations):
    filename = example + '.jpg'
    img_path = os.path.join(image_dir, filename)
    with tf.gfile.GFile(img_path, 'rb') as fid:
        encoded_jpg = fid.read()
    encoded_jpg_io = io.BytesIO(encoded_jpg)
    image = PIL.Image.open(encoded_jpg_io)
    if image.format != 'JPEG':
        raise ValueError('Image format not JPEG')
        logging.warning('This is warning message')
    key = hashlib.sha256(encoded_jpg).hexdigest()
    image_width = image.size[0]
    image_height = image.size[1]
    print ('image_width',image_width)
    print ('image_height',image_height)
    image_class = []

    xmins = []
    ymins = []
    xmaxs = []
    ymaxs = []

    xmin = annotations[0][0][0]
    ymin = annotations[1][0][0]
    xmax = annotations[2][0][0]
    ymax = annotations[3][0][0]
    print ('annotations[0][0][0]',annotations[0][0][0])
    print ('annotations[1][0][0]',annotations[1][0][0])
    print ('annotations[2][0][0]',annotations[2][0][0])
    print ('annotations[3][0][0]',annotations[3][0][0])
    image_class.append(int(annotations[4][0][0]))
    image_filename = annotations[5][0]
    print ('filename',image_filename)
    print ('image_class',image_class)
    print ('the type of image_class',type(image_class))
    xmins.append(xmin / image_width)
    ymins.append(ymin / image_height)
    xmaxs.append(xmax / image_width)
    ymaxs.append(ymax / image_height)
    print ('xmins:',xmins)
    print ('ymins:',ymins)
    print ('xmaxs:',xmaxs)
    print ('ymaxs:',ymaxs)

    feature_dict = {
        'image/height': dataset_util.int64_feature(image_height),
        'image/width': dataset_util.int64_feature(image_width),
        'image/filename': dataset_util.bytes_feature(
            image_filename.encode('utf8')),
        'image/source_id': dataset_util.bytes_feature(
            image_filename.encode('utf8')),
        'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/label': dataset_util.int64_list_feature(image_class),
    }
    tf_example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
    return tf_example

def create_tf_record(output_filename,image_dir,examples_list,annotations):
    writer = tf.python_io.TFRecordWriter(output_filename)
    for idx, example in enumerate(examples_list):
        if idx % 100 == 0:
            logging.warning('On image %d of %d', idx, len(examples_list))
        example_index = example.lstrip('0')
        print ('example_index',example_index)
        example_index = int(example_index) - 1
        try:
            tf_example = dict_to_tf_example(example , image_dir,annotations.item(example_index))
            writer.write(tf_example.SerializeToString())
        except ValueError:
            logging.info('Invalid example: %s, ignoring.', xml_path)

    writer.close()

def main(_):
    FLAGS, unparsed = parse_args()
    train_output_path = os.path.join(FLAGS.output_dir, 'vehicle_train.record')
    test_output_path = os.path.join(FLAGS.output_dir, 'vehicle_test.record')
    image_dir = FLAGS.train_data_dir
    train_filename = FLAGS.train_filename_list
    data = scipy.io.loadmat(FLAGS.annotations_file)
    annotations = data['annotations']
    #train_filename,一个文件列表,存储着所有文件的名字
    examples_list = dataset_util.read_examples_list(train_filename)

    random.seed(42)
    random.shuffle(examples_list)
    num_examples = len(examples_list)
    num_train = int(0.7 * num_examples)
    train_examples = examples_list[:num_train]
    val_examples = examples_list[num_train:]

    create_tf_record(train_output_path,image_dir,train_examples,annotations)
    create_tf_record(test_output_path,image_dir,val_examples,annotations)
if __name__ == '__main__':

    tf.app.run()
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值