Tensorflow object detection API训练自己的数据

本文详细介绍了如何使用Tensorflow Object Detection API进行自定义数据的训练,包括安装、数据准备、模型配置、训练及结果测试。首先,需要安装TF Object Detection API及相关依赖,并确保版本匹配。接着,利用labelImg工具进行数据标注并转换为tfrecord格式。然后,选择预训练模型如ssd_mobilenet_v1_coco,并调整config文件以适应自己的数据。训练过程中,使用tensorboard监控损失曲线。最后,导出模型并进行效果测试。
摘要由CSDN通过智能技术生成

一. 安装

    Tensorflow object detection api是tensorflow官方出品的检测工具包,集成了像ssd、faster rcnn等检测算法,mobilenet、inception、resnet等backbone和fpn、ppn等方法,各模块之间能够通过组合的方式来work。

    Github下载地址:https://github.com/tensorflow/models

    解压models-master,主要的内容都在/research/目录下,里面有很多代码包括ocr、nlp、speech等,本节我们只需要关注object_detection,也就是目标检测部分。

    代码安装步骤如下:

    1)确保电脑上已安装了tensorflow环境+keras,建议tf版本>=1.6,注意tf版本要与CUDA、cudnn环境匹配,否则可能会有意想不到的错误。ps:tensorflow版本升级过快,前向兼容性不好(接口不一致),确实得吐槽。

    2)编译protoc (建议版本3.4+)
         cd research
         protoc --python_out=. object_detection/protos/*.proto

    3)安装research & slim
         python setup.py install
         cd slim
         python setup.py install

    4)添加系统路径(research目录)
         export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
         也可以添加到系统变量:
         vim ~/.bashrc
         export PYTHONPATH=$PYTHONPATH:/path/to/research:/path/to/research/slim
         source ~/.bashrc

    5)测试是否安装成功(research目录)
         python object_detection/builders/model_builder_test.py
         安装成功会显示ok。

二. 数据准备

    数据标注工具建议使用labelImg,采用xml进行数据保存,相关教程比较多,这里不再赘述。

    这里需要说明的是:tensorflow需要tfrecord格式数据,需要在完成的标注数据基础上进行数据转换,这里将标注数据分为两个文件夹,train和test,文件夹下包含图片文件和xml。

    需要通过脚本进行数据转换:首先将annotation转换成csv,然后将csv转换成tfrecord。

       python xml_to_csv.py /path/to/train ./data/train_labels.csv
       python xml_to_csv.py /path/to/test ./data/test_labels.csv
       
       python generate_tfrecord.py --csv_input=data/train_labels.csv  --output_path=train.record
       python generate_tfrecord.py --csv_input=data/test_labels.csv  --output_path=test.record

# xml_to_csv.py
# -*- coding: utf-8 -*-

import os, sys
import glob
import pandas as pd
import xml.etree.ElementTree as ET

def xml_to_csv(_path, _out_file):
    xml_list = []
    for xml_file in glob.glob(_path + '/*.xml'):
        tree = ET.parse(xml_file)
        root = tree.getroot()
        for member in root.findall('object'):
            value = (root.find('filename').text,
                     int(root.find('size')[0].text),
                     int(root.find('size')[1].text),
                     member[0].text,
                     int(member[4][0].text),
              
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值