YOLOX 训练自己的数据集

YOLOX 是旷视提出的一个新的高性能检测器,从论文来看,backbone和neck相比前作变化不大,主要更改在于输入端的数据增强加入了mixup,预测层相比前作改为了解耦的方式。

本文将针对一个Finger识别项目,介绍ubuntu命令行下,yoloX从环境搭建到模型训练的整个过程。

算法原理阅读:
GitHUB
论文
CVer

1.环境

训练使用的环境如下:
Ubuntu 20.04
python 3.8.8
torch 1.8.2+cu111
torchvision 0.9.2+cu111
NVIDIA driver 470.74
CUDA 11.5

相关安装和校验见上一篇文章

2.代码模型下载

2.1 下载模型

git clone https://github.com/Megvii-BaseDetection/YOLOX

2.2下载预训练模型
yolox_x的预训练模型

wget https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_x.pth

2.3 下载模型依赖项

cd YOLOX
pip3 install -r requirements.txt

如果报错,打开根目录下的requirements.txt,手动安装

sudo apt install xxx

3.数据集制作

3.1 Yolomark标注
Yolomark的使用参考上一篇文章

YOLOX使用的是VOC COCO数据集,制作完之后我们需要转换yolomark格式到VOC格式。当然,可以一开始就使用VOC相关工具生成VOC的格式。

4.训练设置

4.1 数据转移
a)在YOLOX/dataset文件下新建一个文件夹FingerDetect用于存放我们制作的数据集。在其中新建VOC格式的4个文件夹 Annotations,ImageSets,JPEGImages。再新建一个labels文件夹。
在这里插入图片描述
b)在ImageSets下新建一个Main文件夹,里面新建2个txt文件:train.txtval.txt ,用于稍后存储分组信息。

在这里插入图片描述

c)将用yolomark标注的jpg文件放入JPEGImages,标签txt放入labels。

d)新建一个.py文件,仿照以下代码编写,将.txt标签转化为.xml标签,注意路径设置。笔者这里的转换相对路径见代码后附的图片。

import os
import xml.etree.ElementTree as ET
from xml.dom.minidom import Document
import cv2

'''
import xml
xml.dom.minidom.Document().writexml()
def writexml(self,
             writer: Any,
             indent: str = "",
             addindent: str = "",
             newl: str = "",
             encoding: Any = None) -> None
'''

class YOLO2VOCConvert:
    def __init__(self, txts_path, xmls_path, imgs_path):
        self.txts_path = txts_path   # 标注的yolo格式标签文件路径
        self.xmls_path = xmls_path   # 转化为voc格式标签之后保存路径
        self.imgs_path = imgs_path   # 读取读片的路径各图片名字,存储到xml标签文件中
        self.classes = ['Finger']

    # 从所有的txt文件中提取出所有的类别, yolo格式的标签格式类别为数字 0,1,...
    # writer为True时,把提取的类别保存到'./Annotations/classes.txt'文件中
    def search_all_classes(self, writer=False):
        # 读取每一个txt标签文件,取出每个目标的标注信息
        all_names = set()
        txts = os.listdir(self.txts_path)
        # 使用列表生成式过滤出只有后缀名为txt的标签文件
        txts = [txt for txt in txts if txt.split('.')[-1] == 'txt']
        print(len(txts), txts)
        # 11 ['0002030.txt', '0002031.txt', ... '0002039.txt', '0002040.txt']
        for txt in txts:
            txt_file = os.path.join(self.txts_path, txt)
            with open(txt_file, 'r') as f:
                objects = f.readlines()
                for object in objects:
                    object = object.strip().split(' ')
                    print(object)  # ['2', '0.506667', '0.553333', '0.490667', '0.658667']
                    all_names.add(int(object[0]))
            # print(objects)  # ['2 0.506667 0.553333 0.490667 0.658667\n', '0 0.496000 0.285333 0.133333 0.096000\n', '8 0.501333 0.412000 0.074667 0.237333\n']

        print("所有的类别标签:", all_names, "共标注数据集:%d张" % len(txts))

        return list(all_names)

    def yolo2voc(self):
        # 创建一个保存xml标签文件的文件夹
        if not os.path.exists(self.xmls_path):
            os.mkdir(self.xmls_path)

        # 把上面的两个循环改写成为一个循环:
        imgs = os.listdir(self.imgs_path)
        txts = os.listdir(self.txts_path)
        txts = [txt for txt in txts if not txt.split('.')[0] == "classes"]  # 过滤掉classes.txt文件
        print(txts)
        # 注意,这里保持图片的数量和标签txt文件数量相等,且要保证名字是一一对应的   (后面改进,通过判断txt文件名是否在imgs中即可)
        if len(imgs) == len(txts):   # 注意:./Annotation_txt 不要把classes.txt文件放进去
            map_imgs_txts = [(img, txt) for img, txt in zip(imgs, txts)]
            txts = [txt for txt in txts if txt.split('.')[-1] == 'txt']
            print(len(txts), txts)
            for img_name, txt_name in map_imgs_txts:
                # 读取图片的尺度信息
                print("读取图片:", img_name)
                img = cv2.imread(os.path.join(self.imgs_path, img_name))
                height_img, width_img, depth_img = img.shape
                print(height_img, width_img, depth_img)   # h 就是多少行(对应图片的高度), w就是多少列(对应图片的宽度)

                # 获取标注文件txt中的标注信息
                all_objects = []
                txt_file = os.path.join(self.txts_path, txt_name)
                with open(txt_file, 'r') as f:
                    objects = f.readlines()
                    for object in objects:
                        object = object.strip().split(' ')
                        all_objects.append(object)
                        print(object)  # ['2', '0.506667', '0.553333', '0.490667', '0.658667']

                # 创建xml标签文件中的标签
                xmlBuilder = Document()
                # 创建annotation标签,也是根标签
                annotation = xmlBuilder.createElement("annotation")

                # 给标签annotation添加一个子标签
                xmlBuilder.appendChild(annotation)

                # 创建子标签folder
                folder = xmlBuilder.createElement("folder")
                # 给子标签folder中存入内容,folder标签中的内容是存放图片的文件夹,例如:JPEGImages
                folderContent = xmlBuilder.createTextNode(self.imgs_path.split('/')[-1])  # 标签内存
                folder.appendChild(folderContent)  # 把内容存入标签
                annotation.appendChild(folder)   # 把存好内容的folder标签放到 annotation根标签下

                # 创建子标签filename
                filename = xmlBuilder.createElement("filename")
                # 给子标签filename中存入内容,filename标签中的内容是图片的名字,例如:000250.jpg
                filenameContent = xmlBuilder.createTextNode(txt_name.split('.')[0] + '.jpg')  # 标签内容
                filename.appendChild(filenameContent)
                annotation.appendChild(filename)

                # 把图片的shape存入xml标签中
                size = xmlBuilder.createElement("size")
                # 给size标签创建子标签width
                width = xmlBuilder.createElement("width")  # size子标签width
                widthContent = xmlBuilder.createTextNode(str(width_img))
                width.appendChild(widthContent)
                size.appendChild(width)   # 把width添加为size的子标签
                # 给size标签创建子标签height
                height = xmlBuilder.createElement("height")  # size子标签height
                heightContent = xmlBuilder.createTextNode(str(height_img))  # xml标签中存入的内容都是字符串
                height.appendChild(heightContent)
                size.appendChild(height)  # 把width添加为size的子标签
                # 给size标签创建子标签depth
                depth = xmlBuilder.createElement("depth")  # size子标签width
                depthContent = xmlBuilder.createTextNode(str(depth_img))
                depth.appendChild(depthContent)
                size.appendChild(depth)  # 把width添加为size的子标签
                annotation.appendChild(size)   # 把size添加为annotation的子标签

                # 每一个object中存储的都是['2', '0.506667', '0.553333', '0.490667', '0.658667']一个标注目标
                for object_info in all_objects:
                    # 开始创建标注目标的label信息的标签
                    object = xmlBuilder.createElement("object")  # 创建object标签
                    # 创建label类别标签
                    # 创建name标签
                    imgName = xmlBuilder.createElement("name")  # 创建name标签
                    imgNameContent = xmlBuilder.createTextNode(self.classes[int(object_info[0])])
                    imgName.appendChild(imgNameContent)
                    object.appendChild(imgName)  # 把name添加为object的子标签

                    # 创建pose标签
                    pose = xmlBuilder.createElement("pose")
                    poseContent = xmlBuilder.createTextNode("Unspecified")
                    pose.appendChild(poseContent)
                    object.appendChild(pose)  # 把pose添加为object的标签

                    # 创建truncated标签
                    truncated = xmlBuilder.createElement("truncated")
                    truncatedContent = xmlBuilder.createTextNode("0")
                    truncated.appendChild(truncatedContent)
                    object.appendChild(truncated)

                    # 创建difficult标签
                    difficult = xmlBuilder.createElement("difficult")
                    difficultContent = xmlBuilder.createTextNode("0")
                    difficult.appendChild(difficultContent)
                    object.appendChild(difficult)

                    # 先转换一下坐标
                    # (objx_center, objy_center, obj_width, obj_height)->(xmin,ymin, xmax,ymax)
                    x_center = float(object_info[1])*width_img + 1
                    y_center = float(object_info[2])*height_img + 1
                    xminVal = int(x_center - 0.5*float(object_info[3])*width_img)   # object_info列表中的元素都是字符串类型
                    yminVal = int(y_center - 0.5*float(object_info[4])*height_img)
                    xmaxVal = int(x_center + 0.5*float(object_info[3])*width_img)
                    ymaxVal = int(y_center + 0.5*float(object_info[4])*height_img)

                    # 创建bndbox标签(三级标签)
                    bndbox = xmlBuilder.createElement("bndbox")
                    # 在bndbox标签下再创建四个子标签(xmin,ymin, xmax,ymax) 即标注物体的坐标和宽高信息
                    # 在voc格式中,标注信息:左上角坐标(xmin, ymin) (xmax, ymax)右下角坐标
                    # 1、创建xmin标签
                    xmin = xmlBuilder.createElement("xmin")  # 创建xmin标签(四级标签)
                    xminContent = xmlBuilder.createTextNode(str(xminVal))
                    xmin.appendChild(xminContent)
                    bndbox.appendChild(xmin)
                    # 2、创建ymin标签
                    ymin = xmlBuilder.createElement("ymin")  # 创建ymin标签(四级标签)
                    yminContent = xmlBuilder.createTextNode(str(yminVal))
                    ymin.appendChild(yminContent)
                    bndbox.appendChild(ymin)
                    # 3、创建xmax标签
                    xmax = xmlBuilder.createElement("xmax")  # 创建xmax标签(四级标签)
                    xmaxContent = xmlBuilder.createTextNode(str(xmaxVal))
                    xmax.appendChild(xmaxContent)
                    bndbox.appendChild(xmax)
                    # 4、创建ymax标签
                    ymax = xmlBuilder.createElement("ymax")  # 创建ymax标签(四级标签)
                    ymaxContent = xmlBuilder.createTextNode(str(ymaxVal))
                    ymax.appendChild(ymaxContent)
                    bndbox.appendChild(ymax)

                    object.appendChild(bndbox)
                    annotation.appendChild(object)  # 把object添加为annotation的子标签
                f = open(os.path.join(self.xmls_path, txt_name.split('.')[0]+'.xml'), 'w')
                xmlBuilder.writexml(f, indent='\t', newl='\n', addindent='\t', encoding='utf-8')
                f.close()

if __name__ == '__main__':
    # 把yolo的txt标签文件转化为voc格式的xml标签文件
    # yolo格式txt标签文件相对路径
    txts_path1 = './YOLOMarkFile/labels'
    # 转化为voc格式xml标签文件存储的相对路径
    xmls_path1 = './VOCFile/Annotations'
    # 存放图片的相对路径
    imgs_path1 = './YOLOMarkFile/images'

    yolo2voc_obj1 = YOLO2VOCConvert(txts_path1, xmls_path1, imgs_path1)
    labels = yolo2voc_obj1.search_all_classes()
    print('labels: ', labels)
    yolo2voc_obj1.yolo2voc()

在这里插入图片描述

e)将转换得到的xml文件放到YOLOX/dataset/FingerDetect/Annotations文件夹下。
在这里插入图片描述xml文件如下格式
在这里插入图片描述

4.2数据分组
需要将标注图片分为train / val两种用途。作引导的txt文件内容格式:
在这里插入图片描述
可以新建一个py文件,用random函数将它们按照一定比例分配成两组。当然,还是推荐手动写txt 。引导文件格式简单,Ctrl+C/V/H比码代码快多了… 😃

最终生成两个txt文件:
在这里插入图片描述

4.3 训练参数配置

由于YOLOX使用的是COCO数据集,我们使用自己数据集时,需要反复修改YOLOX的多个配置文件。笔者首次复现时遇到了各种古怪的问题,最后发现几乎都是因为配置文件没有完全修改完毕。

修改前,建议将YOLOX/tools中的train.pyYOLOX/exps/example/yolox_voc中的yolox_voc_x.py文件拷贝到根目录下,分别作为训练入口和训练指引文件。

4.3.1 修改训练集信息
修改YOLOX/yolox_voc_x.py中的VOCDetection

Before:
在这里插入图片描述After:

yolox/data/datasets/
修改YOLOX/yolox/data/datasets/voc.py中的VOCDection函数(注意,以后使用COCO,还需要改回去~):

Before:
在这里插入图片描述

After:
在这里插入图片描述
4.3.2 修改验证集信息

修改YOLOX/yolox_voc_x.py中的get_eval_loader,去除COCO信息。

Before:
在这里插入图片描述

After:
在这里插入图片描述

4.3.3 修改网络结构
修改YOLOS/exps/default/中的yolox_s.py,可以适当调整网络深度和宽度信息。但是需要和YOLOX/yolox_voc_x.py以及YOLOX/yolox/exp/yolox_base.py 保持一致。
在这里插入图片描述

4.3.4 修改数据集类型
a) 修改YOLOX/yolox/data/datasets/中的voc_classes.py ,将本来的类型去掉,换成我们自己的类型。

在这里插入图片描述b) 修改YOLOX/yolox_voc_x.pyYOLOX/yolox/exp/yolox_base.py中的参数self.num_classes,与我们的类型匹配。YOLOX默认也是80个标注类型。

Before:
在这里插入图片描述
After:
在这里插入图片描述

4.3.5 修改验证过程配置
a) 修改YOLOX/yolox/data/datasets/voc.py中的 _do_python_eval函数,去除COCO信息。

Before:
在这里插入图片描述在这里插入图片描述在这里插入图片描述

After:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
4.4 修改训练文件
打开YOLOX/train.py,修改make_parser中的主要参数。
在这里插入图片描述
主要修改项:
权重 - 选择预训练模型,连接到下载的yolox_x权重

parser.add_argument("-c", "--ckpt", default='weights/yolox_x.pth', type=str, help="checkpoint file")

模型 - 选择使用的模型,连接到yolox_x.py文件

parser.add_argument("-n", "--name", type=str, default='datasets/FingerDetect/yolox_x.py', help="model name")

数据集 - 连接到引导文件YOLOX/yolox_voc_s_Finger.py

parser.add_argument("-f","--exp_file",default='./yolox_voc_s_Finger.py',type=str,help="plz input your experiment description file",)

Batch大小 - 训练块大小。过大可能会CUDA报错空间不足

parser.add_argument("-b", "--batch-size", type=int, default=8, help="batch size")

EPoch数量 - 最大训练次数。
需要在YOLOX/yolox/exp/yolox_base.py中修改。
在这里插入图片描述
4.5 开始训练

cd YOLOX
python3 train.py

5. 训练结果

训练结果默认放置在YOLOX/YOLOX_outputs/FingerDetect文件夹下。
在这里插入图片描述

6.结果展示

可以用tensorboard可视化结果,例如以5.训练结果为例:

cd YOLOX/YOLOX_outputs
tensorboard --logdir=FingerDetect

在命令行输出的以下结果中右键打开localhost连接
在这里插入图片描述

即可在浏览器中看到tensorboard的显示结果
在这里插入图片描述

  • 2
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值