TensorFlow Object Detection API训练自己的模型并进行识别

本文详述了在Windows环境下,利用TensorFlow Object Detection API进行目标检测的步骤,包括准备数据集、创建record文件、下载预训练模型、修改配置文件、训练、导出模型以及预测和结果展示。作者分享了数据组织结构、创建txt文件、创建record文件的代码,并提醒了在训练过程中可能遇到的问题及其解决方案。
摘要由CSDN通过智能技术生成

写在前面

本文的环境:window10、python3.7.2、anaconda3.4,TensorFlow是通过anaconda自动安装的,版本是1.3.1,然后已经安装好了TensorFlow Object Detection API。这些在我上一篇文章【TensorFlow Object Detection API 安装】有。

【1】准备自己的数据集


首先准备好自己的图片,并且规范的命名(便于后续处理),保存在img文件夹中。然后在同级目录下新建xml文件夹,用来保存标签文件。标签通过通过labelImg软件完成。labelImg网盘下载, 提取码:lyi6 。下载后直接解压就可以使用。

在使用labelImg时要注意两个路径,open dir对应选择img文件夹,change save dir对应选择xml文件夹。解压出来的文件predefined_classes.txt中,可以预先定义自己要打的标签的名称。

按【W】进行标记,【ctrl+s】保存,【a】【d】切换图片。

最后值得说一下的是,因为xml文件里面会保存文件路径信息,因此最好是确定img和xml文件夹的最终位置后再进行。

我创建的文件目录如下:

  • object_detection
    • mydata
      • data
        • img
        • xml
        • train.txt
        • val.txt

如上面所示,还需要创建两个txt文件,分别是train.txt,val.txt,用来保存训练集和验证集的文件名称(不需要文件类型)。例如我的train.txt和val.txt的内容如下:

【2】创建record文件和pbtxt文件

这个我直接把我的create_tf_recored文件贴出来,按照我之前的配置可以直接用(要灵活使用的话需要读者去看懂,然后根据自己的配置更改代码)。

r"""Convert dataset to TFRecord for object_detection.

Example usage:
    python object_detection/dataset_tools/create__tf_record.py \
        --data_dir=mydata/data \
        --set=train \
        --output_path=mydata/train.record
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import hashlib
import io
import logging
import os

from lxml import etree
import PIL.Image
import tensorflow as tf

from object_detection.utils import dataset_util
from object_detection.utils import label_map_util


flags = tf.app.flags
flags.DEFINE_string('data_dir', '', 'Root directory to raw PASCAL VOC dataset.')
flags.DEFINE_string('set', 'train', 'Convert training set, validation set or '
                    'merged set.')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
flags.DEFINE_string('label_map_path', 'mydata/mydata_label_map.pbtxt',
                    'Path to label map proto')
flags.DEFINE_boolean('ignore_difficult_instances', False, 'Whether to ignore '
                     'difficult instances')
FLAGS = flags.FLAGS

SETS = ['train', 'val', 'trainval', 'test']


def dict_to_tf_example(data,
                       dataset_directory,
                       label_map_dict,
                       ignore_difficult_instances=False,
                       image_subdirectory='img'):
  """Convert XML derived dict to tf.Example proto.

  Notice that this function normalizes the bounding box coordinates provided
  by the raw data.

  Args:
    data: dict holding PASCAL XML fields for a single image (obtained by
      running dataset_util.recursive_parse_xml_to_dict)
    dataset_directory: Path to root directory holding PASCAL dataset
 
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值