跑通基于YOLOv5的旋转框目标检测

首先介绍一下如何制作带有角度的YOLOv5数据集标注。

先去这个网站下标注软件代码GitHub - cgvict/roLabelImg: Label Rotated Rect On Images for training

随后为其创建独立的虚拟环境

conda create -n rolabel36 python=3.6

激活环境

conda activate rolabel36

安装对应的依赖包

pip install pyqt5-tools
pip install lxml
pyrcc5 -o resources.py resources.qrc 

在项目根目录下运行打开标注软件的程序

python roLabelImg.py

选中需要标注的文件夹即可对数据集进行标注,标注示例如下图所示。

然而软件标注的格式是xml格式,网络训练需要的标注格式是txt文件,因此需要把xml文件转换为txt格式。以下代码可以实现这个转换,只需要把XMLDIR和OUTDIR改为自己的路径即可。

import os
import os.path as osp
import math

BASEDIR = osp.dirname(osp.abspath(__file__))    #获取的当前执行脚本的完整路径
XMLDIR = osp.join(BASEDIR, r'C:\Users\BJUT\Desktop\label')
OUTDIR = osp.join(BASEDIR, r'C:\Users\BJUT\Desktop\labels')
if not osp.exists(OUTDIR):
    os.makedirs(OUTDIR)
# pi=3.1415926
xmlnames = [i for i in os.listdir(XMLDIR) if i.endswith('.xml')]
# print(names)
pi = 3.1415926


# 转换成四点坐标
def convert(cx, cy, w, h, a):
    if a >= pi:
        a -= pi
    # 计算斜径半长
    l = math.sqrt(w ** 2 + h ** 2) / 2
    # 计算初始矩形角度
    a0 = math.atan(h / w)
    # 旋转,计算旋转角
    # 右上角点 ↗
    a1 = a0 + a
    x1 = cx + l * math.cos(a1)
    y1 = cy + l * math.sin(a1)
    # 右下角点 ↘
    a2 = -a0 + a
    x2 = cx + l * math.cos(a2)
    y2 = cy + l * math.sin(a2)
    # 左下角点 ↙
    a3 = a1 + pi
    x3 = cx + l * math.cos(a3)
    y3 = cy + l * math.sin(a3)
    # 左上角点 ↖
    a4 = a2 + pi
    x4 = cx + l * math.cos(a4)
    y4 = cy + l * math.sin(a4)
    return [x1, y1, x2, y2, x3, y3, x4, y4]


# 点关于直线对称
for xmlname in xmlnames:
    cx = []
    cy = []
    w = []
    h = []
    angle = []
    name = []

    txtname = xmlname.split('.')[-2] + '.txt'       #从右往左数,右边第一个是-1

    with open(osp.join(OUTDIR, txtname), 'w') as fp:
        fp.write('')                                   #向文件中写入东西
    with open(osp.join(XMLDIR, xmlname), 'rb') as fp:
        lines = fp.readlines()     #依次读取每行
    for line in lines:
        # print(line, end='')
        if line.strip().startswith(b'<width>'):   #strip()为去掉首尾空格
            img_width = eval(line.strip().strip(b'<width>').strip(b'</width>')) #eval函数的作用是获取返回值
            # print(img_width)
        if line.strip().startswith(b'<height>'):
            img_height = eval(line.strip().strip(b'<height>').strip(b'</height>'))
            # print(img_height)
        if line.strip().startswith(b'<depth>'):
            img_depth = eval(line.strip().strip(b'<depth>').strip(b'</depth>'))
            # print(img_depth)
        if line.strip().startswith(b'<cx>'):
            cx.append(eval(line.strip().strip(b'<cx>').strip(b'</cx>')))  #空列表里加上读取的值
            # print(cx)
        if line.strip().startswith(b'<cy>'):
            cy.append(eval(line.strip().strip(b'<cy>').strip(b'</cy>')))
            # print(cy)
        if line.strip().startswith(b'<w>'):
            w.append(eval(line.strip().strip(b'<w>').strip(b'</w>')))
            # print(w)
        if line.strip().startswith(b'<h>'):
            h.append(eval(line.strip().strip(b'<h>').strip(b'</h>')))
            # print(h)
        if line.strip().startswith(b'<angle>'):
            angle.append(eval(line.strip().strip(b'<angle>').strip(b'</angle>')))
        if line.strip().startswith(b'<name>'):
            name.append(line.strip().strip(b'<name>').strip(b'</name>').decode('utf-8'))
            # print(angle)
    #with open(osp.join(OUTDIR, txtname), 'a') as fp:
        #fp.write("imagesource:GoogleEarth")
        #fp.write('\n')
        #fp.write("gsd:0.146343590398")
        #fp.write('\n')
    for i in range(len(cx)):
        cls0 = 0.0
        cx_i = cx[i]
        cy_i = cy[i]
        w_i = w[i]
        h_i = h[i]
        a_i = angle[i]
        object = name[i]

        x0, y0, x1, y1, x2, y2, x3, y3 = convert(cx_i, cy_i, w_i, h_i, a_i)

        put_str = ' '.join(
            [str(x0), str(y0), str(x1), str(y1), str(x2), str(y2), str(x3), str(y3), str(object), str(0)])
        #put_str = ' '.join([str(cx_i), str(cy_i), str(w_i), str(h_i), str(a_i)])
        with open(osp.join(OUTDIR, txtname), 'a') as fp:
            fp.write(put_str)

            fp.write('\n')
    print(xmlname, 'to', txtname, 'done.')

接着,将数据集按如下方式进行分类(数据集的名字可以换,但是对应的图像和标注的名字必须是images和labelTxt)。

接下来我们就可以训练模型了。

训练模型的第一步当然是下载代码。

https://github.com/hukaixuan19970627/yolov5_obb

为其创建独特的虚拟环境。

conda create -n YOLOv5 python=3.8

进入所创建的虚拟环境。

conda activate YOLOv5

在安装依赖包之前,需要安装与pytorch相对应的cuda版本,此项目使用的pytorch为1.7.0。去pytorch官网查看对应的CUDA版本(pytorch官网地址:PyTorch)。如下图所示,项目需要安装CUDA10.2。

通过命令行输入nvidia-smi查看自己的显卡驱动版本以及支持的最大CUDA版本,下图第一行就显示了这些信息,可以看到,最大支持CUDA12.2,因此可以放心地安装cuda10.2。

进入下面这个网页下载CUDA10.2

CUDA Toolkit 10.2 Download | NVIDIA Developer

可以通过命名行下载,也可以自行下载。下载后运行固定命令安装,安装命令如下:

sudo sh cuda_10.2.89_440.33.01_linux.run

安装好cuda以后,再安装pytorch,进入pytorch官网(Previous PyTorch Versions | PyTorch),找到对应的pytorch版本,运行相应的指令。例如我们需要1.7.0版本的pytorch,需要运行在创建的虚拟环境下运行这条指令:

conda install pytorch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0 cudatoolkit=10.2 -c pytorch

等待安装完成即可,如果安装过程很慢,考虑换源,还有一点,一定要关掉VPN。

安装完pytorch和CUDA以后呢,我们可以再检查一下pytorch和cuda的版本是否匹配!(要在刚刚创建的YOLOv5虚拟环境内)首先在终端输入python,接着一一输入以下三条命令,如果返回的是True说明我们的pytorch和cuda的版本匹配,如果返回的是False,说明两者不匹配。

import torch
print(torch.__version__)
print(torch.cuda.is_available())

接下来就可以安装依赖包了(必须在YOLOv5环境内)

pip install -r requirements.txt

等待安装完成即可。

接下来按着yolov5_obb作者的步骤走。

cd utils/nms_rotated
python setup.py develop
cd yolov5_obb/DOTA_devkit
sudo apt-get install swig
swig -c++ -python polyiou.i
python setup.py build_ext --inplace

中途我也遇到过问题,但一一百度,都能解决。

ok,下一步是训练模型!这部分和官方的YOLOv5步骤差不多。

1、在/data/scripts目录下创建自己的data,复制粘贴coco.yaml并重新命名为my_data.yaml。里面数据集的地址改成自己数据集的地址。我的数据集地址是这样:

2、编辑train.py :workers最好设置为0 ,weights设为空(如果嫌浪费时间,可以加入预训练模型),cfg是关于模型参数的一些设置(在models/hub下可以看到一些已经给的参数设置,复制yolov5s.yaml路径到cfg中),data为指定数据集(也就是刚刚创建的my_data.yaml),对应的default填其路径,在这里应为data/my_data.yaml。epochs代表训练的轮数,自己设置即可。

3、运行train.py,等待训练完成即可。

  • 16
    点赞
  • 33
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值