tensorflow object_detection 用自己的数据训练目标检测模型Mobilenet

官网目标检测模块从安装到实现的所有步骤都在 ./models-master/object_detection/g3doc

1.准备数据

首先必须要准备标注好的数据(xml文件),以及训练测试文件目录(train.txt,val.txt),然后用

./models-master/object_detection/create_pascal_tf_record.py文件制作数据.record数据格式的train,val文件

我的命令行是这样的

python /home/saners/Mobilenet/makeTest/createtf.py --label_map_path=/home/saners/Mobilenet/makeTest/own_label_map.pbtxt --data_dir=/home/saners/Mobilenet/makeTest --set=train --output_path=/home/saners/Mobilenet/makeTest/train.record

python /home/saners/Mobilenet/makeTest/createtf.py --label_map_path=/home/saners/Mobilenet/makeTest/own_label_map.pbtxt --data_dir=/home/saners/Mobilenet/makeTest --set=val --output_path=/home/saners/Mobilenet/makeTest/val.record


#1我这里的createtf.py就是create_pascal_tf_record.py,因为用了自己的数据,路径什么的有些不对,做了小小的更改,其实也没改什么,还是贴一下吧


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import hashlib
import io
import logging
import os

from lxml import etree
import PIL.Image
import tensorflow as tf

import sys
sys.path.append('/home/saners/Mobilenet/models-master')   #这里模块路径找不到,我手动加了一下
from object_detection.utils import dataset_util
from object_detection.utils import label_map_util

flags = tf.app.flags
flags.DEFINE_string('data_dir', '', 'Root directory to raw PASCAL VOC dataset.')
flags.DEFINE_string('set', 'train', 'Convert training set, validation set or '
                    'merged set.')
flags.DEFINE_string('annotations_dir', 'Annotations',
                    '(Relative) path to annotations directory.')
flags.DEFINE_string('year', 'VOC2007', 'Desired challenge year.')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
flags.DEFINE_string('label_map_path', 'data/pascal_label_map.pbtxt',
                    'Path to label map proto')
flags.DEFINE_boolean('ignore_difficult_instances', False, 'Whether to ignore '
                     'difficult instances')

# flags.DEFINE_string('image_puted', 'The file to storage image')
FLAGS = flags.FLAGS

SETS = ['train', 'val', 'trainval', 'test']
# YEARS = ['VOC2007', 'VOC2012', 'merged']    #这里注释了,因为是自己的数据而且本来这个地方也可有可无,这是官方下载的数据文件夹也是组成路劲的字符串

def dict_to_tf_example(data,
                       dataset_directory,
                       label_map_dict,
                       ignore_difficult_instances=False,
                       image_subdirectory='JPEGImages'):
 
  img_path = os.path.join(cla, image_subdirectory, data['filename'])    #这个cla全局变量是我后面定义的我的类别文件夹名字,因为要获取文件夹名组成数据路径
  full_path = os.path.join(dataset_directory, img_path)

  with tf.gfile.GFile(full_path, 'rb') as fid:
    encoded_jpg = fid.read()
  encoded_jpg_io = io.BytesIO(encoded_jpg)
  image = PIL.Image.open(encoded_jpg_io)
  if image.format != 'JPEG':
    raise ValueError('Image format not JPEG')
  key = hashlib.sha256(encoded_jpg).hexdigest()

  width = int(data['size']['width'])
  height = int(data['size']['height'])

  xmin = []
  ymin = []
  xmax = []
  ymax = []
  classes = []
  classes_text = []
  truncated = []
  poses = []
  difficult_obj = []
  label_name=cla.split('_')[1]

  for obj in data['object']:
    difficult = bool(int(obj['difficult']))
    if ignore_difficult_instances and difficult:
      continue

    difficult_obj.append(int(difficult))
    xmin.append(float(obj['bndbox']['xmin']) / width)
    ymin.append(float(obj['bndbox']['ymin']) / height)
    xmax.append(float(obj['bndbox']['xmax']) / width)
    ymax.append(float(obj['bndbox']['ymax']) / height)
    classes_text.append(obj['name'].encode('utf8'))
    classes.append(label_map_dict[obj['name']])
    truncated.append(int(obj['truncated']))
    poses.append(obj['pose'].encode('utf8'))

  example = tf.train.Example(features=tf.train.Features(feature={
      'image/height': dataset_util.int64_feature(height),
      'image/width': dataset_util.int64_feature(width),
      'image/filename': dataset_util.bytes_feature(
          data['filename'].encode('utf8')),
      'image/source_id': dataset_util.bytes_feature(
          data['filename'].encode('utf8')),
      'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
      'image/encoded': dataset_util.bytes_feature(encoded_jpg),
      'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
      'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
      'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
      'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
      'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
      'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
      'image/object/class/label': dataset_util.int64_list_feature(classes),
      'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
      'image/object/truncated': dataset_util.int64_list_feature(truncated),
      'image/object/view': dataset_util.bytes_list_feature(poses),
  }))
  return example

def main(_):
  if FLAGS.set not in SETS:
    raise ValueError('set must be in : {}'.format(SETS))
  # if FLAGS.year not in YEARS:                                              #这里注释了吧,对应前面注释的year
  #   raise ValueError('year must be in : {}'.format(YEARS))


  data_dir = FLAGS.data_dir
  # years = ['VOC2007', 'VOC2012']     #同上
  # if FLAGS.year != 'merged':
  #   years = [FLAGS.year]

  writer = tf.python_io.TFRecordWriter(FLAGS.output_path)

  label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)


  # for year in years:
    # logging.info('Reading from PASCAL %s dataset.', year)   #同上
  class_file=os.listdir(data_dir)
  global cla                                                                                #这里定义了全局变量cla
  for cla in class_file:
    examples_path = os.path.join(data_dir, cla, 'ImageSets', 'Main', FLAGS.set + '.txt')    #  train.txt或者val.txt存放的位置
    annotations_dir = os.path.join(data_dir, cla, FLAGS.annotations_dir)                          #xml文件存放的位置
    if os.path.exists(examples_path):
      examples_list = dataset_util.read_examples_list(examples_path)
    else:
      continue

    for idx, example in enumerate(examples_list):
      if idx % 100 == 0:
        logging.info('On image %d of %d', idx, len(examples_list))
      path = os.path.join(annotations_dir, example + '.xml')
      with tf.gfile.GFile(path, 'r') as fid:
        xml_str = fid.read()
      xml = etree.fromstring(xml_str)
      data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']
      tf_example = dict_to_tf_example(data, FLAGS.data_dir, label_map_dict,
                                      FLAGS.ignore_difficult_instances)
      writer.write(tf_example.SerializeToString())

  writer.close()

if __name__ == '__main__':
  tf.app.run()

#2own_label_map.pbtxt 这个文件是我改了./models-master/object_detection/data/pascal_label_map.pbtxt文件中的类别,我知识了一类所以就是这样

item {
  id: 1
  name: 'aircraft'
}

#3--data_dir就是你放图像数据的跟目录

#4 --set说明数据是train还是val

#5 --output_path 就是输出.record文件的地方,文件名最好是train.record或者val.record,方便辨认


我的目录结构

+dataTest  #数据根目录

    +category_aircraft   #我只有一类,类别根目录,注意这个文件夹里面的文件夹名字请保持一致

        +Annotations  #里面是所有图像的xml文件

        +imageSets   #放置train.txt和val.txt文件的地方

            +Main       #里面是train.txt和val.txt文件,多这一层是因为官方提供的生成数据的文件里面有这个路径字符串,前面文件保持一致也是这个原因,不过你想改也行,记得改源码

               train.txt

               val.txt

       +JPEGImages #里面是图像,必须是JPEG格式的即.jpg

    own_label_map.pbtxt  #类别文件

2.准备训练文件

我的目录结构是这样的

+example

    +data

         own_label_map.pbtxt

         train.record   #训练数据,前面产生了

         val.record     #测试数据

    +models

       +model  #放置训练结果的地方

  ssd_mobilenet_v1_pets.config #模型配置文件,位置在./models-master/object_detection/samples/configs中,这个文件需要更改,你打开以后会有提示

  大概说一下num_classes改为你的类别个数,这一句我注释了fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt",因为我是重新训练模型,不是在某个模型基础上开始的,from_detection_checkpoint: true  表示检查点来自检测模型,false表示检查点来自分类模型  这是官方文档原话:`from_detection_checkpoint` is a boolean value. If false, it assumes the checkpoint was from an object classification checkpoint. Note that starting from a detection checkpoint will usually result in a faster training job than a classification checkpoint.(大神帮忙理解一下对不对)

这部分以后的更改都是对应文件的路径,就不在累述了。

num_steps: 200000和num_examples: 2000 一个是训练迭代次数一个是val 的数据量,第二个参数的说明是在其他博客上看到的,不敢保证就是对的,自己再慢慢研究吧!


以上都做完后就可以训练了,我的命令行如下:

python /home/saners/Mobilenet/models-master/object_detection/train.py \
    --logtostderr \
    --pipeline_config_path=/home/saners/Mobilenet/exampleTest2/models/ssd_mobilenet_v1_pets.config \
    --train_dir=/home/saners/Mobilenet/exampleTest2/models/model/


这个训练还在研究中,后续还会更新

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

如雾如电

随缘

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值