YOLOX训练高精度车辆检测

序言

最近项目需要用到高精度的车辆检测模型,最开始想的是直接使用yolov5或者yolox的coco预训练的模型,但是发现在实际场景中精度并不是很好,因为COCO数据集中有80个类,并且数据集的错误标签挺多的,直接白嫖到场景中会存在很多误检漏检的情况,所以还是得自己训练一个专门用来检测车辆的模型,本文记录一下训练过程。模型训练可以选择yolov5,也可以选择yolox,两者都是性能非常好的算法,本文采用yolox进行训练。

一、数据集分析

一开始我想的是直接从coco数据集中提取出car、bus、truck三个类作为车辆数据集直接进行训练,但是后来将车辆数据集分离出来后发现数据并不太理想,有很多奇奇怪怪错误的数据,索性直接放弃。然后转向公开的关于车辆检测的数据集,因为我的业务场景偏向于俯视的监控镜头,所以我决定用VisDrone数据集进行训练,VisDrone数据集是一个基于无人机视角的检测和跟踪数据集,大概的视角如下:
在这里插入图片描述
这类的场景还是比较符合我的需求,并且在监控中,基本上都是俯视的视角,再者考虑到可能还会运用到其他的场景中,所以在VisDrone数据集的基础上加上了KITTI数据集,KITTI数据集是一个自动驾驶视角的目标检测数据集,检测目标包括了行人和车辆等等:
在这里插入图片描述
因此,将这两个场景的数据合并起来一起训练,再加上部分coco数据集分离出来经过精选后的图片,基本上就可以满足所有的车辆检测场景了。如果以上场景数据集还是不满足的话可以加入更多的数据集:

行人检测数据集汇总1

行人、车辆检测数据集汇总2

相关数据集和代码提供百度云,需要的朋友可自行下载。

链接: https://pan.baidu.com/s/1njSDxDWb5gu20_AaRPi3WQ?pwd=wq7p

提取码: wq7p

二、数据集准备

因为以上两类数据集的标注格式并非我们熟悉的COCO格式或者VOC格式,所以下载下来后还需要对其标注进行解析得到相关的VOC格式标注标签,先将两个数据集进行下载,下载下来后的VisDrone数据格式label是txt文件,所以需要代码将其转成我们熟悉的VOC格式,转换代码如下:

"""
该脚本用于visdrone数据处理;
将annatations文件夹中的txt标签文件转换为XML文件;
txt标签内容为:
<bbox_left>,<bbox_top>,<bbox_width>,<bbox_height>,<score>,<object_category>,<truncation>,<occlusion>
类别:
ignored regions(0), pedestrian(1),
people(2), bicycle(3), car(4), van(5),
truck(6), tricycle(7), awning-tricycle(8),
bus(9), motor(10), others(11)
"""

import os
import cv2
import time
from xml.dom import minidom

name_dict = {'0': 'ignored regions', '1': 'pedestrian', '2': 'people',
             '3': 'bicycle', '4': 'car', '5': 'van', '6': 'truck',
             '7': 'tricycle', '8': 'awning-tricycle', '9': 'bus',
             '10': 'motor', '11': 'others'}


def transfer_to_xml(pic, txt, file_name):
    xml_save_path = r'VisDrone2019-DET-train\xml'            # 生成的xml文件存储的文件夹
    if not os.path.exists(xml_save_path):
        os.mkdir(xml_save_path)

    img = cv2.imread(pic)
    img_w = img.shape[1]
    img_h = img.shape[0]
    img_d = img.shape[2]
    doc = minidom.Document()

    annotation = doc.createElement("annotation")
    doc.appendChild(annotation)
    folder = doc.createElement('folder')
    folder.appendChild(doc.createTextNode('visdrone'))
    annotation.appendChild(folder)

    filename = doc.createElement('filename')
    filename.appendChild(doc.createTextNode(file_name))
    annotation.appendChild(filename)

    source = doc.createElement('source')
    database = doc.createElement('database')
    database.appendChild(doc.createTextNode("Unknown"))
    source.appendChild(database)

    annotation.appendChild(source)

    size = doc.createElement('size')
    width = doc.createElement('width')
    width.appendChild(doc.createTextNode(str(img_w)))
    size.appendChild(width)
    height = doc.createElement('height')
    height.appendChild(doc.createTextNode(str(img_h)))
    size.appendChild(height)
    depth = doc.createElement('depth')
    depth.appendChild(doc.createTextNode(str(img_d)))
    size.appendChild(depth)
    annotation.appendChild(size)

    segmented = doc.createElement('segmented')
    segmented.appendChild(doc.createTextNode("0"))
    annotation.appendChild(segmented)

    with open(txt, 'r') as f:
        lines = [f.readlines()]
        for line in lines:
            for boxes in line:
                box = boxes.strip('\n')
                box = box.split(',')
                x_min = box[0]
                y_min = box[1]
                x_max = int(box[0]) + int(box[2])
                y_max = int(box[1]) + int(box[3])
                object_name = name_dict[box[5]]

                # if object_name is 'ignored regions' or 'others':
                #     continue

                object = doc.createElement('object')
                nm = doc.createElement('name')
                nm.appendChild(doc.createTextNode(object_name))
                object.appendChild(nm)
                pose = doc.createElement('pose')
                pose.appendChild(doc.createTextNode("Unspecified"))
                object.appendChild(pose)
                truncated = doc.createElement('truncated')
                truncated.appendChild(doc.createTextNode("1"))
                object.appendChild(truncated)
                difficult = doc.createElement('difficult')
                difficult.appendChild(doc.createTextNode("0"))
                object.appendChild(difficult)
                bndbox = doc.createElement('bndbox')
                xmin = doc.createElement('xmin')
                xmin.appendChild(doc.createTextNode(x_min))
                bndbox.appendChild(xmin)
                ymin = doc.createElement('ymin')
                ymin.appendChild(doc.createTextNode(y_min))
                bndbox.appendChild(ymin)
                xmax = doc.createElement('xmax')
                xmax.appendChild(doc.createTextNode(str(x_max)))
                bndbox.appendChild(xmax)
                ymax = doc.createElement('ymax')
                ymax.appendChild(doc.createTextNode(str(y_max)))
                bndbox.appendChild(ymax)
                object.appendChild(bndbox)
                annotation.appendChild(object)
                with open(os.path.join(xml_save_path, file_name + '.xml'), 'w') as x:
                    x.write(doc.toprettyxml())
                x.close()
    f.close()

if __name__ == '__main__':
    t = time.time()
    print('Transfer .txt to .xml...ing....')
    txt_folder = r'VisDrone2019-DET-train\annotations'  # visdrone txt标签文件夹
    txt_file = os.listdir(txt_folder)
    img_folder = r'VisDrone2019-DET-train\images'  # visdrone 照片所在文件夹

    for txt in txt_file:
        txt_full_path = os.path.join(txt_folder, txt)
        img_full_path = os.path.join(img_folder, txt.split('.')[0] + '.jpg')

        try:
            transfer_to_xml(img_full_path, txt_full_path, txt.split('.')[0])
        except Exception as e:
            print(e)

    print("Transfer .txt to .XML sucessed. costed: {:.3f}s...".format(time.time() - t))

转换后的XML标注文件有12个类,但是我只想取其中的汽车类进行训练,VisDrone中的汽车类有:car、bus、truck、van,所以使用如下代码提取一下类别,重新生成新的数据集:

# VOC数据集提取某个类或者某些类
# !/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import xml.etree.ElementTree as ET
import shutil

# 根据自己的情况修改相应的路径
ann_filepath = r'Annotations/'
img_filepath = r'JPEGImages/'
img_savepath = r'imgs/'
ann_savepath = r'xmls/'
if not os.path.exists(img_savepath):
    os.mkdir(img_savepath)

if not os.path.exists(ann_savepath):
    os.mkdir(ann_savepath)

# 这是VOC数据集中所有类别
# classes = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
#             'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
#              'dog', 'horse', 'motorbike', 'pottedplant',
#           'sheep', 'sofa', 'train', 'person','tvmonitor']

classes = ['car','bus','truck','van']  # 这里是需要提取的类别

def save_annotation(file):
    tree = ET.parse(ann_filepath + '/' + file)
    root = tree.getroot()
    result = root.findall("object")
    bool_num = 0
    for obj in result:
        if obj.find("name").text not in classes:
            root.remove(obj)
        else:
            bool_num = 1
    if bool_num:
        tree.write(ann_savepath + file)
        return True
    else:
        return False

def save_images(file):
    name_img = img_filepath + os.path.splitext(file)[0] + ".png"
    shutil.copy(name_img, img_savepath)
    # 文本文件名自己定义,主要用于生成相应的训练或测试的txt文件
    with open('train.txt', 'a') as file_txt:
        file_txt.write(os.path.splitext(file)[0])
        file_txt.write("\n")
    return True


if __name__ == '__main__':
    for f in os.listdir(ann_filepath):
        print(f)
        if save_annotation(f):
            save_images(f)

再然后,我希望把这几个类合并成一个类,统一使用car标签,类别修改合并代码如下:

#!/usr/bin/env python2
# -*- coding: utf-8 -*-

import os
import xml.etree.ElementTree as ET

origin_ann_dir = r'xmls_old/'  # 设置原始标签路径为 Annos
new_ann_dir = r'xmls_new/'  # 设置新标签路径 Annotations
for dirpaths, dirnames, filenames in os.walk(origin_ann_dir):  # os.walk游走遍历目录名
    for filename in filenames:
        print("process...")
        if os.path.isfile(r'%s%s' % (origin_ann_dir, filename)):  # 获取原始xml文件绝对路径,isfile()检测是否为文件 isdir检测是否为目录
            origin_ann_path = os.path.join(r'%s%s' % (origin_ann_dir, filename))  # 如果是,获取绝对路径(重复代码)
            new_ann_path = os.path.join(r'%s%s' % (new_ann_dir, filename))
            tree = ET.parse(origin_ann_path)  # ET是一个xml文件解析库,ET.parse()打开xml文件。parse--"解析"
            root = tree.getroot()  # 获取根节点
            for object in root.findall('object'):  # 找到根节点下所有“object”节点
                name = str(object.find('name').text)  # 找到object节点下name子节点的值(字符串)
                if (name in ["car","bus","truck","van"]):
                    object.find('name').text = "car"

            tree.write(new_ann_path)  # tree为文件,write写入新的文件中。

至此,VisDrone数据集制作完毕,将得到只含一个car类xml标注格式的数据集。然后再看看KITTI数据集,KITTI数据集下载链接中,只需要下载这部分即可:
在这里插入图片描述
下载下来后在ubuntu上unzip解压可能会有问题,不要慌,搜一下报错能找到答案,好像是用7z解压的,有点忘记了,可以自己搜一下。

KITTI数据集的标签也是txt文件,也需要转换一下,我需要把汽车类、行人类合并一下,最后只保留三个类Car,Pedestrian,Cyclist:

# modify_annotations_txt.py
#将原来的8类物体转换为我们现在需要的3类:Car,Pedestrian,Cyclist。
#我们把原来的Car、Van、Truck,Tram合并为Car类,把原来的Pedestrian,Person(sit-ting)合并为现在的Pedestrian,原来的Cyclist这一类保持不变。
import glob
import string
txt_list = glob.glob('kitti/data_object_image_2/training/label/label_2/*.txt')
def show_category(txt_list):
    category_list= []
    for item in txt_list:
        try:
            with open(item) as tdf:
                for each_line in tdf:
                    labeldata = each_line.strip().split(' ') # 去掉前后多余的字符并把其分开
                    category_list.append(labeldata[0]) # 只要第一个字段,即类别
        except IOError as ioerr:
            print('File error:'+str(ioerr))
    print(set(category_list)) # 输出集合
def merge(line):
    each_line=''
    for i in range(len(line)):
        if i!= (len(line)-1):
            each_line=each_line+line[i]+' '
        else:
            each_line=each_line+line[i] # 最后一条字段后面不加空格
    each_line=each_line+'\n'
    return (each_line)
print('before modify categories are:\n')
show_category(txt_list)
for item in txt_list:
    new_txt=[]
    try:
        with open(item, 'r') as r_tdf:
            for each_line in r_tdf:
                labeldata = each_line.strip().split(' ')
                if labeldata[0] in ['Truck','Van','Tram']: # 合并汽车类
                    labeldata[0] = labeldata[0].replace(labeldata[0],'Car')
                if labeldata[0] == 'Person_sitting': # 合并行人类
                    labeldata[0] = labeldata[0].replace(labeldata[0],'Pedestrian')
                if labeldata[0] == 'DontCare': # 忽略Dontcare类
                    continue
                if labeldata[0] == 'Misc': # 忽略Misc类
                    continue
                new_txt.append(merge(labeldata)) # 重新写入新的txt文件
        with open(item,'w+') as w_tdf: # w+是打开原文件将内容删除,另写新内容进去
            for temp in new_txt:
                w_tdf.write(temp)
    except IOError as ioerr:
        print('File error:'+str(ioerr))
print('\nafter modify categories are:\n')
show_category(txt_list)

得到新的txt标注文件,新的label只有这三个类,然后将其装换成xml格式的标注:

# kitti_txt_to_xml.py
# encoding:utf-8
# 根据一个给定的XML Schema,使用DOM树的形式从空白文件生成一个XML
from xml.dom.minidom import Document
import cv2
import os

def generate_xml(name,split_lines,img_size,class_ind):
    doc = Document() # 创建DOM文档对象
    annotation = doc.createElement('annotation')
    doc.appendChild(annotation)
    title = doc.createElement('folder')
    title_text = doc.createTextNode('KITTI')
    title.appendChild(title_text)
    annotation.appendChild(title)
    img_name=name+'.png'
    title = doc.createElement('filename')
    title_text = doc.createTextNode(img_name)
    title.appendChild(title_text)
    annotation.appendChild(title)
    source = doc.createElement('source')
    annotation.appendChild(source)
    title = doc.createElement('database')
    title_text = doc.createTextNode('The KITTI Database')
    title.appendChild(title_text)
    source.appendChild(title)
    title = doc.createElement('annotation')
    title_text = doc.createTextNode('KITTI')
    title.appendChild(title_text)
    source.appendChild(title)
    size = doc.createElement('size')
    annotation.appendChild(size)
    title = doc.createElement('width')
    title_text = doc.createTextNode(str(img_size[1]))
    title.appendChild(title_text)
    size.appendChild(title)
    title = doc.createElement('height')
    title_text = doc.createTextNode(str(img_size[0]))
    title.appendChild(title_text)
    size.appendChild(title)
    title = doc.createElement('depth')
    title_text = doc.createTextNode(str(img_size[2]))
    title.appendChild(title_text)
    size.appendChild(title)
    for split_line in split_lines:
        line=split_line.strip().split()
        if line[0] in class_ind:
            object = doc.createElement('object')
            annotation.appendChild(object)
            title = doc.createElement('name')
            title_text = doc.createTextNode(line[0])
            title.appendChild(title_text)
            object.appendChild(title)

            title = doc.createElement('pose')
            title_text = doc.createTextNode("Unspecified")
            title.appendChild(title_text)
            object.appendChild(title)

            title = doc.createElement('truncated')
            title_text = doc.createTextNode(str(0))
            title.appendChild(title_text)
            object.appendChild(title)

            title = doc.createElement('difficult')
            title_text = doc.createTextNode(str(0))
            title.appendChild(title_text)
            object.appendChild(title)

            bndbox = doc.createElement('bndbox')
            object.appendChild(bndbox)
            title = doc.createElement('xmin')
            title_text = doc.createTextNode(str(int(float(line[4]))))
            title.appendChild(title_text)
            bndbox.appendChild(title)
            title = doc.createElement('ymin')
            title_text = doc.createTextNode(str(int(float(line[5]))))
            title.appendChild(title_text)
            bndbox.appendChild(title)
            title = doc.createElement('xmax')
            title_text = doc.createTextNode(str(int(float(line[6]))))
            title.appendChild(title_text)
            bndbox.appendChild(title)
            title = doc.createElement('ymax')
            title_text = doc.createTextNode(str(int(float(line[7]))))
            title.appendChild(title_text)
            bndbox.appendChild(title)
    # 将DOM对象doc写入文件
    f = open('VOCdevkit/VOCKITTI/Annotations/'+name+'.xml','w')     # xml保存路径
    f.write(doc.toprettyxml(indent='\t'))
    f.close()

if __name__ == '__main__':

    class_ind=('Pedestrian', 'Car', 'Cyclist')
    txt = "kitti/training/label"        # 新的txt
    for file_name in os.listdir(txt):
        full_path=os.path.join(txt, file_name) # 获取文件全路径
        f=open(full_path)
        split_lines = f.readlines()
        name= file_name[:-4] # 后四位是扩展名.txt,只取前面的文件名
        img_name=name+'.png'
        img_path=os.path.join('VOCdevkit/VOCKITTI/JPEGImages',img_name)   # 图片路径
        img_size=cv2.imread(img_path).shape
        print(img_path)
        generate_xml(name,split_lines,img_size,class_ind)
    print('all txts has converted into xmls')

得到xml标注格式之后,需要用之前VisDrone处理的代码,单独先将Car类提取出来,然后需要把Car标签改成car(大小写),当然这两步也可以在其他步骤里完成,最后将两个数据集的数据进行合并即可,将两个数据集的xml文件放在同一个文件夹中,图片放在同一个文件夹中。然后就可以开始训练了。

三、模型训练

yolox的模型训练需要参考我之前的文章:YOLOX自定义数据集训练,这里我就不重复写了,简单阐述一下训练过程:

  1. 划分数据集;
  2. 修改相应配置文件;
  3. 开启训练。

我使用的是tiny和nano两个模型进行训练,训练过程还是比较顺利的,训练了300个轮次,注意因为这两个数据集的难度比较大,可能训练出来的ap没这么高,不过没关系,训练完后我们测试一下就知道了。

我只使用了部分的KITTI数据集(1000张)混在VisDrone中训练,如果使用全部数据的话效果可能会比我的nano的训练效果要好一些,因为KITTI数据集要比VisDrone简单,使用tiny的话要高几个百分点,以下是nano和tiny的训练结果:
在这里插入图片描述
在这里插入图片描述

四、模型测试

因为coco数据集训练的nano和tiny用于检测车的精度都比较低,我直接拿重新训练后的nano和使用COCO预训练的s模型进行对比,输入大小都是832,首先是COCO训练的s模型,然后是重新训练后用于检测车的nano模型,
第一组对比如下:
在这里插入图片描述
在这里插入图片描述

第二组对比:
在这里插入图片描述
在这里插入图片描述
第三组对比:
在这里插入图片描述
在这里插入图片描述
跨模型对比可以看到效果还是不错的,随机在网上找了一些密集的车辆图片来测试:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
图片在这里插入图片描述
在这里插入图片描述

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值