Object Detection API 安装以及使用总结(二)

使用Object Detection API训练自己的数据

迟来的填坑
1.数据标注和存放
一般目标检测的标注工具使用labelImg,每一张图会对应生成一个XML文件。网上博客大多采用PASCAL VOC标准的文件存放方式来放置文件的位置,实际用的时候我觉得比较冗余,这里只需要准备以下数据:
在这里插入图片描述

  • JPEGImages
    存放原始的训练数据(图片)
  • Annotations
    存放标记生成的xml文档
  • .pbtxt文件
    根据标记的类别名称,将所有的数据类别手动按如下格式的编辑.pbtxt格式的文件
    在这里插入图片描述

2.数据格式转化
为了训练的高效处理,Object Detection API 需要将训练数据转化成tfrecord
按照上述形式放置好文件后使用如下代码进行格式转换

# create by guanqiuyu 
#
# 2018.10.26
#
# Harbin Institute of Technology Visual Technology Labrary
# ==============================================================

import argparse
import tensorflow as tf
import logging
import os
import io
import PIL.Image
import hashlib
import time
import random

from object_detection.utils import dataset_util
from object_detection.utils import label_map_util
from lxml import etree

parser = argparse.ArgumentParser()

parser.add_argument('--ratio', required = True, type = float ,help = 'the ratio of train and val')
parser.add_argument('--path_data', required = True, help = 'the dir of your training data')
parser.add_argument('--path_label_dict', required = True, help = 'the dir of your .pbxt')
                    
opt = parser.parse_args()

SETS = ['train','val','trainval','test']
def IsSubString(SubStrList, Str):
    flag = True
    for substr in SubStrList:
        if not (substr in Str):
            flag = False

    return flag

def GetFileList(FindPath, FlagStr=[]):
    '''get list of all id in the dir'''
    import os
    FileList = []
    FileNames = os.listdir(FindPath)
    if (len(FileNames) > 0):
        for fn in FileNames:
            if (len(FlagStr) > 0):
                if (IsSubString(FlagStr, fn)):
                    FileList.append(fn[:-4])
            else:
                FileList.append(fn)

    if (len(FileList) > 0):
        FileList.sort()

    return FileList

def partial_data_set(path_usr,path,ratio):
    '''split data set'''
    f = open(path)
    train_list = open(path_usr + '/train.txt', 'w+')
    val_list = open(path_usr + '/val.txt', 'w+')

    for line in f:
        if random.random()<ratio:
            val_list.write(line)
        else:
            train_list.write(line)

    f.close()
    train_list.close()
    val_list.close()
    
def get_list(path_txt,path_user):
    '''Generate the list of Image data'''
    with open(path_txt, 'w+') as list_file:  # 数据集的图片list保存路径
        if not os.path.exists('%s/Annotations/' % path_user):
            os.makedirs('%s/Annotations/' % path_user)
        image_ids = GetFileList(path_user + '/Annotations/', ['xml'])
        for image_id in image_ids:
            print(image_id)
            list_file.write(image_id+'\n')

def covert_data_to_tfrecord(data,label_dict,path_data,pathname = 'JPEGImages'):
    img_path = os.path.join(path_data,pathname,data['filename'])
    with open(img_path,'rb') as f:
        image_data = f.read()
    # image_data_io = io.BytesIO(image_data)
    # image = PIL.Image.open(image_data_io)
    '''sha256加密'''
    key = hashlib.sha256(image_data).hexdigest()
    width = int(data['size']['width'])
    height = int(data['size']['height'])

    xmax = []
    xmin = []
    ymax = []
    ymin = []
    classes = []
    classes_text = []
    truncated = []
    poses = []
    difficult_obj = []

    if 'object' in data:
        for obj in data['object']:
            difficult = bool(int(obj['difficult']))
            difficult_obj.append(int(difficult))

            xmin.append(float(obj['bndbox']['xmin']) / width)
            ymin.append(float(obj['bndbox']['ymin']) / height)
            xmax.append(float(obj['bndbox']['xmax']) / width)
            ymax.append(float(obj['bndbox']['ymax']) / height)

            classes_text.append(obj['name'].encode('utf8'))
            classes.append(label_dict[obj['name']])
            truncated.append(int(obj['truncated']))
            poses.append(obj['pose'].encode('utf8'))
    
    example = tf.train.Example(features = tf.train.Features(feature = {
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(
            data['filename'].encode('utf8')),
        'image/source_id': dataset_util.bytes_feature(
            data['filename'].encode('utf8')),
        'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
        'image/encoded': dataset_util.bytes_feature(image_data),
        'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
        'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
        'image/object/truncated': dataset_util.int64_list_feature(truncated),
        'image/object/view': dataset_util.bytes_list_feature(poses),
    }))
    return example





def main():
    # if opt.mode not in SETS:
    #     raise ValueError('Please input true value from {}'.format(SETS))
    #create dataset.txt
    path_data = opt.path_data
    path_dataset_txt = os.path.join(path_data,'dataset.txt')
    get_list(path_dataset_txt,path_data)
    #use the dataset.txt to generate train.txt and val.txt
    partial_data_set(path_data,path_dataset_txt,opt.ratio)
    datas = ['train','val']
    #get dict from the .pbxt
    label_dict = label_map_util.get_label_map_dict(opt.path_label_dict)
    for i_data , data in enumerate(datas):
        writer = tf.python_io.TFRecordWriter(os.path.join(path_data,data + '.tfrecords'))
        print(i_data ,'------', data)
        logging.info('open the '+ data + '.txt')

        datalist_from_txt = dataset_util.read_examples_list(os.path.join(path_data,data+'.txt')) 
        for i,datalist in enumerate(datalist_from_txt):
            if i % 100 == 0:
                logging.info('image %d ----total %d',i,len(datalist_from_txt))
            # data_name = datalist.split('.')[0].split('/')[-1]
            path = os.path.join(path_data,'Annotations', datalist + '.xml')
            with open(path,'r') as f:
                xml_str = f.read()
            xml_content = etree.fromstring(xml_str)
            data = dataset_util.recursive_parse_xml_to_dict(xml_content)['annotation']
            tf_example = covert_data_to_tfrecord(data,label_dict,path_data)
            writer.write(tf_example.SerializeToString())
            
        writer.close()

    
if __name__ == '__main__':
    # print(time.strftime("%b+ %B +%a %A",time.localtime()))
    main()

执行命令python xml_to_record.py --ratio #R --path_data #D --path_label_dict #T.pbxt

  • #R —— 分割数据集的比例 ,如训练集:验证集 = 9:1 ,则设为0.1
  • #D —— 存放Annotations 和 JPEGImages的文件目录
  • #T —— .pbxt文件的地址

比如我的设置为python xml_to_record.py --ratio 0.1 --path_data /home/stargrain/data/train --path_label_dict /home/stargrain/data/train/label_map.pbxt

执行不报错,会生成train.tfrecords 和 val.tfrecords 两个文件

第二步完成

3.下载模型并训练

我是地址
在model_zoos选择需要的网络模型,下载并解压,更改pipeline.config文件
num_classes 改为自己数据的类别数
在这里插入图片描述
最下面的路径,finetune的权重、.pbxt文件、两个tfrecods文件 的路径修改为自己的
最后cd 到models/research/object_detection/legacy目录下找到train.py文件
使用python train.py --logtostderr --train_dir=输出目录地址 --pipeline_config_path=刚刚修改的config文件地址

PS.
注意这里输出目录千万不要和加载的权重文件在同一个目录下,否则会报错!!!
注意这里输出目录千万不要和加载的权重文件在同一个目录下,否则会报错!!!
注意这里输出目录千万不要和加载的权重文件在同一个目录下,否则会报错!!!

以上

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值