TensorFlow Introductory Tutorial (38): Object Detection API Source Code Analysis - Converting a Dataset to TFRecord

#
# Author: Wei Fang
# Blog: https://blog.csdn.net/rookie_wei
# WeChat: 1007895847
# (please mention CSDN when adding me on WeChat)
# Everyone is welcome to learn together
#

1. Overview

Let's continue analyzing the Object Detection API source code. Please read this post together with my three earlier posts on the Object Detection API:

https://blog.csdn.net/rookie_wei/article/details/81143814

https://blog.csdn.net/rookie_wei/article/details/81210499

https://blog.csdn.net/rookie_wei/article/details/81275663

We pick up from the last of those posts and start the analysis from training your own model.

2. Source code analysis: converting the VOC2012 dataset to TFRecord format

Given the command:

python dataset_tools/create_pascal_tf_record.py --data_dir=my_images/VOCdevkit/ --year=VOC2012 --output_path=my_images/VOCdevkit/pascal_train.record --set=train

we start from the dataset_tools/create_pascal_tf_record.py file to see how the image data is converted into TFRecord format. Beginning with the main function:

def main(_):
  if FLAGS.set not in SETS:
    raise ValueError('set must be in : {}'.format(SETS))
  if FLAGS.year not in YEARS:
    raise ValueError('year must be in : {}'.format(YEARS))

So the set and year parameters must hold valid values, although they default to train and VOC2007 respectively. The remaining flags are defined as follows:

flags = tf.app.flags
flags.DEFINE_string('data_dir', '', 'Root directory to raw PASCAL VOC dataset.')
flags.DEFINE_string('set', 'train', 'Convert training set, validation set or '
                    'merged set.')
flags.DEFINE_string('annotations_dir', 'Annotations',
                    '(Relative) path to annotations directory.')
flags.DEFINE_string('year', 'VOC2007', 'Desired challenge year.')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
flags.DEFINE_string('label_map_path', 'data/pascal_label_map.pbtxt',
                    'Path to label map proto')
flags.DEFINE_boolean('ignore_difficult_instances', False, 'Whether to ignore '
                     'difficult instances')
FLAGS = flags.FLAGS

SETS = ['train', 'val', 'trainval', 'test']
YEARS = ['VOC2007', 'VOC2012', 'merged']

These flags are straightforward, so let's move on.
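
As an aside, tf.app.flags behaves much like Python's standard argparse. The following standalone sketch mirrors the flag definitions above (illustration only, not the script's actual code):

```python
import argparse

# Rough argparse equivalent of the tf.app.flags definitions above.
SETS = ['train', 'val', 'trainval', 'test']
YEARS = ['VOC2007', 'VOC2012', 'merged']

parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', default='',
                    help='Root directory to raw PASCAL VOC dataset.')
parser.add_argument('--set', default='train', choices=SETS)
parser.add_argument('--year', default='VOC2007', choices=YEARS)
parser.add_argument('--output_path', default='')

# Same arguments as the command at the top of this post
args = parser.parse_args(['--data_dir', 'my_images/VOCdevkit/',
                          '--year', 'VOC2012',
                          '--set', 'train',
                          '--output_path', 'my_images/VOCdevkit/pascal_train.record'])
print(args.year, args.set)  # VOC2012 train
```

The `choices=` arguments play the role of the explicit `if FLAGS.set not in SETS` checks in main.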

data_dir = FLAGS.data_dir
years = ['VOC2007', 'VOC2012']
if FLAGS.year != 'merged':
  years = [FLAGS.year]

#Writer that will save the TFRecord file
writer = tf.python_io.TFRecordWriter(FLAGS.output_path)

#Parse the label map file (data/pascal_label_map.pbtxt); the result looks like:
#{'pottedplant': 16, 'diningtable': 11, 'sheep': 17, 'aeroplane': 1, 'bicycle': 2,
# 'person': 15, 'bus': 6, 'train': 19, 'sofa': 18, 'car': 7, 'chair': 9, 'dog': 12,
# 'bottle': 5, 'bird': 3, 'motorbike': 14, 'cow': 10, 'tvmonitor': 20, 'horse': 13,
# 'boat': 4, 'cat': 8}
label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)

This simply parses the data/pascal_label_map.pbtxt file, whose format is:

item {
  id: 1
  name: 'aeroplane'
}
item {
  id: 2
  name: 'bicycle'
}
item {
  id: 3
  name: 'bird'
}
item {
  id: 4
  name: 'boat'
}
item {
  id: 5
  name: 'bottle'
}
item {
  id: 6
  name: 'bus'
}
item {
  id: 7
  name: 'car'
}
item {
  id: 8
  name: 'cat'
}
item {
  id: 9
  name: 'chair'
}
item {
  id: 10
  name: 'cow'
}
item {
  id: 11
  name: 'diningtable'
}
item {
  id: 12
  name: 'dog'
}
item {
  id: 13
  name: 'horse'
}
item {
  id: 14
  name: 'motorbike'
}
item {
  id: 15
  name: 'person'
}
item {
  id: 16
  name: 'pottedplant'
}
item {
  id: 17
  name: 'sheep'
}
item {
  id: 18
  name: 'sofa'
}
item {
  id: 19
  name: 'train'
}
item {
  id: 20
  name: 'tvmonitor'
}
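
To make the result of label_map_util.get_label_map_dict concrete, here is a minimal sketch that extracts the id/name pairs with a regex. The real helper parses the file as a StringIntLabelMap protobuf and handles more fields, so treat this purely as an illustration:

```python
import re

# Two entries copied from the pbtxt listing above
pbtxt = """
item {
  id: 1
  name: 'aeroplane'
}
item {
  id: 2
  name: 'bicycle'
}
"""

def parse_label_map(text):
    # Match every item { id: N name: '...' } block and build a name -> id dict
    items = re.findall(r"item\s*{\s*id:\s*(\d+)\s*name:\s*'([^']+)'\s*}", text)
    return {name: int(i) for i, name in items}

label_map_dict = parse_label_map(pbtxt)
print(label_map_dict)  # {'aeroplane': 1, 'bicycle': 2}
```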

These are the 20 classes we want to detect. Continuing:

for year in years:
  logging.info('Reading from PASCAL %s dataset.', year)

  # data_dir/VOC2012/ImageSets/Main/aeroplane_train.txt
  examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main',
                               'aeroplane_' + FLAGS.set + '.txt')
  # data_dir/VOC2012/Annotations
  annotations_dir = os.path.join(data_dir, year, FLAGS.annotations_dir)

  # Store the first field of every line of aeroplane_train.txt in a list;
  # these fields are file name stems without the extension
  examples_list = dataset_util.read_examples_list(examples_path)
  print(len(examples_list))

Since we passed VOC2012 as the FLAGS.year parameter, year here is 'VOC2012'. So examples_path is the file data_dir/VOC2012/ImageSets/Main/aeroplane_train.txt, and annotations_dir is the folder data_dir/VOC2012/Annotations.

dataset_util.read_examples_list reads aeroplane_train.txt and stores the first field of each line in a list. Each line holds an image id followed by a class-presence flag (1 present, -1 absent, 0 difficult), and the id matches a file name in the data_dir/VOC2012/Annotations folder, just without the .xml suffix.

Note that aeroplane_train.txt does not cover every file under data_dir/VOC2012/Annotations; it only lists the images of the train split. Moving on:
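
dataset_util.read_examples_list boils down to keeping the first whitespace-separated field of each line. A self-contained sketch with made-up sample lines in the aeroplane_train.txt format (real ids will differ):

```python
# Sample lines in the "<image id> <presence flag>" format of aeroplane_train.txt
sample = """2008_000008 -1
2008_000015 -1
2008_000033  1
"""

def read_examples_list(text):
    # Keep only the first field of each non-empty line
    lines = text.strip().split('\n')
    return [line.strip().split(' ')[0] for line in lines]

examples_list = read_examples_list(sample)
print(examples_list)  # ['2008_000008', '2008_000015', '2008_000033']
```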

for idx, example in enumerate(examples_list):
  if idx % 100 == 0:
    logging.info('On image %d of %d', idx, len(examples_list))

  #Path of the corresponding 20xx_xxxxxx.xml file under data_dir/VOC2012/Annotations
  path = os.path.join(annotations_dir, example + '.xml')

  #Open the xml file found above and read its contents
  with tf.gfile.GFile(path, 'r') as fid:
    xml_str = fid.read()

  #Parse the xml into a dict
  xml = etree.fromstring(xml_str)
  data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']

  tf_example = dict_to_tf_example(data, FLAGS.data_dir, label_map_dict,
                                  FLAGS.ignore_difficult_instances)

  #Serialize the example into the record file
  writer.write(tf_example.SerializeToString())

The loop above walks through every entry of aeroplane_train.txt and parses the corresponding file under data_dir/VOC2012/Annotations. The key function is dict_to_tf_example, which is easy to follow once you know the annotation format: each XML file has an <annotation> root containing folder, filename, size (width, height, depth), and one <object> element per labeled instance with name, pose, truncated, difficult and bndbox children.
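
To see what dict_to_tf_example receives, here is a simplified re-implementation of dataset_util.recursive_parse_xml_to_dict run on a trimmed-down, made-up annotation. The real helper uses lxml and the real files carry more fields, so treat this as a sketch:

```python
import xml.etree.ElementTree as ET

# A minimal VOC-style annotation (real files also have pose, truncated, etc.)
XML = """
<annotation>
  <folder>VOC2012</folder>
  <filename>2008_000008.jpg</filename>
  <size><width>500</width><height>442</height><depth>3</depth></size>
  <object>
    <name>horse</name>
    <difficult>0</difficult>
    <bndbox><xmin>53</xmin><ymin>87</ymin><xmax>471</xmax><ymax>420</ymax></bndbox>
  </object>
</annotation>
"""

def recursive_parse_xml_to_dict(xml):
    # Leaf element: map its tag to its text content
    if len(xml) == 0:
        return {xml.tag: xml.text}
    result = {}
    for child in xml:
        child_result = recursive_parse_xml_to_dict(child)
        if child.tag != 'object':
            result[child.tag] = child_result[child.tag]
        else:
            # There can be several <object> tags, so collect them in a list
            result.setdefault(child.tag, []).append(child_result[child.tag])
    return {xml.tag: result}

data = recursive_parse_xml_to_dict(ET.fromstring(XML))['annotation']
print(data['size']['width'])                # prints 500 (the string '500')
print(data['object'][0]['bndbox']['xmin'])  # prints 53 (the string '53')
```

Note that every leaf value comes out as a string, which is why dict_to_tf_example wraps them in int() and float() calls.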

def dict_to_tf_example(data,
                       dataset_directory,
                       label_map_dict,
                       ignore_difficult_instances=False,
                       image_subdirectory='JPEGImages'):
  """Convert XML derived dict to tf.Example proto.

  Notice that this function normalizes the bounding box coordinates provided
  by the raw data.

  Args:
    data: dict holding PASCAL XML fields for a single image (obtained by
      running dataset_util.recursive_parse_xml_to_dict)
    dataset_directory: Path to root directory holding PASCAL dataset
    label_map_dict: A map from string label names to integers ids.
    ignore_difficult_instances: Whether to skip difficult instances in the
      dataset  (default: False).
    image_subdirectory: String specifying subdirectory within the
      PASCAL dataset directory holding the actual image data.

  Returns:
    example: The converted tf.Example.

  Raises:
    ValueError: if the image pointed to by data['filename'] is not a valid JPEG
  """

  #Image file that this xml describes, e.g. VOC2012/JPEGImages/2008_000008.jpg
  img_path = os.path.join(data['folder'], image_subdirectory, data['filename'])
  # Prepend dataset_directory, e.g. my_images/VOCdevkit/VOC2012/JPEGImages/2008_000008.jpg
  full_path = os.path.join(dataset_directory, img_path)

First, the folder and filename fields are used to build the path to the image this xml file describes.

Continuing:

#Read the raw image bytes
with tf.gfile.GFile(full_path, 'rb') as fid:
  encoded_jpg = fid.read()

encoded_jpg_io = io.BytesIO(encoded_jpg)
image = PIL.Image.open(encoded_jpg_io)

if image.format != 'JPEG':
  raise ValueError('Image format not JPEG')

#SHA-256 hash of the image bytes
key = hashlib.sha256(encoded_jpg).hexdigest()

#Image width and height, taken from the xml
width = int(data['size']['width'])
height = int(data['size']['height'])

The code reads the image bytes, computes a hash as a unique key, and takes the width and height from the xml.
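
The SHA-256 key is just a hex digest of the raw JPEG bytes. A tiny demonstration with a stand-in byte string rather than a real JPEG (PIL is only needed above to verify the format):

```python
import hashlib

# Stand-in for the bytes read from the image file
encoded_jpg = b'\xff\xd8\xff\xe0 fake jpeg bytes'

# Same call as in the source: a 64-character hex digest uniquely keying the image
key = hashlib.sha256(encoded_jpg).hexdigest()
print(len(key))  # 64
```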

xmin = []
ymin = []
xmax = []
ymax = []
classes = []
classes_text = []
truncated = []
poses = []
difficult_obj = []
if 'object' in data:
  for obj in data['object']:
    #Whether this instance is flagged as 'difficult' to detect
    difficult = bool(int(obj['difficult']))
    if ignore_difficult_instances and difficult:
      continue

    difficult_obj.append(int(difficult))

    #Bounding box corners: top-left (xmin, ymin) and bottom-right (xmax, ymax)
    #in pixels, normalized here to [0, 1] by the image size
    xmin.append(float(obj['bndbox']['xmin']) / width)
    ymin.append(float(obj['bndbox']['ymin']) / height)
    xmax.append(float(obj['bndbox']['xmax']) / width)
    ymax.append(float(obj['bndbox']['ymax']) / height)

    #Class name as a byte string
    classes_text.append(obj['name'].encode('utf8'))
    #Numeric class id from the label map
    classes.append(label_map_dict[obj['name']])
    #Whether the object is truncated (extends beyond the image boundary)
    truncated.append(int(obj['truncated']))
    #Pose / viewpoint string, e.g. 'Frontal'
    poses.append(obj['pose'].encode('utf8'))
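
A worked example of the bounding-box normalization above, with made-up numbers. VOC stores pixel coordinates of the top-left (xmin, ymin) and bottom-right (xmax, ymax) corners; dividing by the image width and height maps them into [0, 1]:

```python
# Hypothetical 500x442 image with one annotated box
width, height = 500, 442
bndbox = {'xmin': 53, 'ymin': 87, 'xmax': 471, 'ymax': 420}

# Same arithmetic as in dict_to_tf_example
xmin = bndbox['xmin'] / width
ymin = bndbox['ymin'] / height
xmax = bndbox['xmax'] / width
ymax = bndbox['ymax'] / height

print(xmin, xmax)  # 0.106 0.942
```

Normalized coordinates let the training pipeline resize images freely without recomputing the boxes.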

This parses the remaining fields out of the xml. Next, everything is packed into a tf.train.Example:

example = tf.train.Example(features=tf.train.Features(feature={
    'image/height': dataset_util.int64_feature(height),
    'image/width': dataset_util.int64_feature(width),
    'image/filename': dataset_util.bytes_feature(
        data['filename'].encode('utf8')),
    'image/source_id': dataset_util.bytes_feature(
        data['filename'].encode('utf8')),
    'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
    'image/encoded': dataset_util.bytes_feature(encoded_jpg),
    'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
    'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
    'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
    'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
    'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
    'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
    'image/object/class/label': dataset_util.int64_list_feature(classes),
    'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
    'image/object/truncated': dataset_util.int64_list_feature(truncated),
    'image/object/view': dataset_util.bytes_list_feature(poses),
}))
return example

All the parsed fields are packed into a tf.train.Example. If tf.train.Example is new to you, this post (in Chinese) explains it:

https://www.jianshu.com/p/b480e5fcb638

Finally, after all records have been written to the TFRecord file, remember to close the writer:

writer.close()

This part is fairly straightforward; following along with the source should make it clear. Wherever something is unclear, add a print and inspect the values.

If you found this post helpful, the author would appreciate your support via an Alipay red packet. Thank you!

 
