Training with the mmrotate rotated object detection framework

Goal: detect objects with arbitrary rotation angles.

I hit a lot of pitfalls setting up the environment early on; most of them turned out to be mmdet / mmcv-full version mismatches.

Preface:

Environment: Ubuntu 20.04, RTX 3050 GPU, mmrotate-0.3.2

1. Setting up the environment

Grab the source. Because of all the earlier pitfalls I am not using the latest code; this walkthrough uses v0.3.2.

# download the v0.3.2 source archive (note: `git clone` on the release-tag page URL does not work)
wget https://github.com/open-mmlab/mmrotate/archive/refs/tags/v0.3.2.tar.gz
tar -xzf v0.3.2.tar.gz

cd mmrotate-0.3.2

# create a conda environment
conda create -n mmrotate python=3.8

conda activate mmrotate

pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html

pip install mmcv-full==1.6.0 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html

pip install mmdet==2.25

pip install -r requirements/build.txt

pip install -v -e .

Pay close attention to the mmdet version when installing it; otherwise you will hit version-mismatch errors later at runtime, such as:

AssertionError: MMCV==1.7.1 is used but incompatible. Please install mmcv>=2.0.0rc4, <2.1.0.

Here is the MMCV / MMDetection / MMDetection3D version compatibility table (also discussed in the CSDN post of the same title by MFT小白):

| MMDetection3D version | MMDetection version | MMSegmentation version | MMCV version |
| --- | --- | --- | --- |
| master | mmdet>=2.24.0, <=3.0.0 | mmseg>=0.20.0, <=1.0.0 | mmcv-full>=1.5.2, <=1.7.0 |
| v1.0.0rc4 | mmdet>=2.24.0, <=3.0.0 | mmseg>=0.20.0, <=1.0.0 | mmcv-full>=1.5.2, <=1.7.0 |
| v1.0.0rc3 | mmdet>=2.24.0, <=3.0.0 | mmseg>=0.20.0, <=1.0.0 | mmcv-full>=1.4.8, <=1.6.0 |
| v1.0.0rc2 | mmdet>=2.24.0, <=3.0.0 | mmseg>=0.20.0, <=1.0.0 | mmcv-full>=1.4.8, <=1.6.0 |
| v1.0.0rc1 | mmdet>=2.19.0, <=3.0.0 | mmseg>=0.20.0, <=1.0.0 | mmcv-full>=1.4.8, <=1.5.0 |
| v1.0.0rc0 | mmdet>=2.19.0, <=3.0.0 | mmseg>=0.20.0, <=1.0.0 | mmcv-full>=1.3.17, <=1.5.0 |
| 0.18.1 | mmdet>=2.19.0, <=3.0.0 | mmseg>=0.20.0, <=1.0.0 | mmcv-full>=1.3.17, <=1.5.0 |
| 0.18.0 | mmdet>=2.19.0, <=3.0.0 | mmseg>=0.20.0, <=1.0.0 | mmcv-full>=1.3.17, <=1.5.0 |
| 0.17.3 | mmdet>=2.14.0, <=3.0.0 | mmseg>=0.14.1, <=1.0.0 | mmcv-full>=1.3.8, <=1.4.0 |
| 0.17.2 | mmdet>=2.14.0, <=3.0.0 | mmseg>=0.14.1, <=1.0.0 | mmcv-full>=1.3.8, <=1.4.0 |
| 0.17.1 | mmdet>=2.14.0, <=3.0.0 | mmseg>=0.14.1, <=1.0.0 | mmcv-full>=1.3.8, <=1.4.0 |
| 0.17.0 | mmdet>=2.14.0, <=3.0.0 | mmseg>=0.14.1, <=1.0.0 | mmcv-full>=1.3.8, <=1.4.0 |
| 0.16.0 | mmdet>=2.14.0, <=3.0.0 | mmseg>=0.14.1, <=1.0.0 | mmcv-full>=1.3.8, <=1.4.0 |
| 0.15.0 | mmdet>=2.14.0, <=3.0.0 | mmseg>=0.14.1, <=1.0.0 | mmcv-full>=1.3.8, <=1.4.0 |
| 0.14.0 | mmdet>=2.10.0, <=2.11.0 | mmseg==0.14.0 | mmcv-full>=1.3.1, <=1.4.0 |
| 0.13.0 | mmdet>=2.10.0, <=2.11.0 | Not required | mmcv-full>=1.2.4, <=1.4.0 |
| 0.12.0 | mmdet>=2.5.0, <=2.11.0 | Not required | mmcv-full>=1.2.4, <=1.4.0 |
| 0.11.0 | mmdet>=2.5.0, <=2.11.0 | Not required | mmcv-full>=1.2.4, <=1.3.0 |
| 0.10.0 | mmdet>=2.5.0, <=2.11.0 | Not required | mmcv-full>=1.2.4, <=1.3.0 |
| 0.9.0 | mmdet>=2.5.0, <=2.11.0 | Not required | mmcv-full>=1.2.4, <=1.3.0 |
| 0.8.0 | mmdet>=2.5.0, <=2.11.0 | Not required | mmcv-full>=1.1.5, <=1.3.0 |
| 0.7.0 | mmdet>=2.5.0, <=2.11.0 | Not required | mmcv-full>=1.1.5, <=1.3.0 |
| 0.6.0 | mmdet>=2.4.0, <=2.11.0 | Not required | mmcv-full>=1.1.3, <=1.2.0 |
| 0.5.0 | 2.3.0 | Not required | mmcv-full==1.0.5 |
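To apply a table row programmatically, a quick sketch of checking whether an installed version falls inside a row's range (this ignores rc suffixes; `in_range` and `vtuple` are illustrative helpers, not part of any OpenMMLab package):

```python
def vtuple(v):
    """'1.6.0' -> (1, 6, 0), so versions compare numerically, not as strings."""
    return tuple(int(x) for x in v.split('.'))

def in_range(installed, low, high):
    """Check low <= installed <= high, as in one row of the compatibility table."""
    return vtuple(low) <= vtuple(installed) <= vtuple(high)

# mmcv-full 1.6.0 against the master-row range
print(in_range('1.6.0', '1.5.2', '1.7.0'))  # True
# mmcv-full 1.7.1 (the version from the error above) is out of range
print(in_range('1.7.1', '1.5.2', '1.7.0'))  # False
```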

2. Testing the environment

2.1 Download a model
(test) ubuntu@ubuntu:~/mmrotate-0.3.2$ mkdir -p model
(test) ubuntu@ubuntu:~/mmrotate-0.3.2$ wget -P model https://download.openmmlab.com/mmrotate/v0.1.0/kfiou/r3det_kfiou_ln_r50_fpn_1x_dota_oc/r3det_kfiou_ln_r50_fpn_1x_dota_oc-8e7f049d.pth
2.2 Run the demo
(test) ubuntu@ubuntu:~/mmrotate-0.3.2$ python demo/image_demo.py demo/demo.jpg configs/kfiou/r3det_kfiou_ln_r50_fpn_1x_dota_oc.py model/r3det_kfiou_ln_r50_fpn_1x_dota_oc-8e7f049d.pth --out-file test.jpg

Environment setup complete!

3. Creating the dataset

Take the barcode on a nameplate as the example target. I recommend annotating with labelme; labelImg also works, but its output then needs a script to convert it to json.

Annotate the four corners in clockwise order, starting from the top-left corner: x1, y1, x2, y2, x3, y3, x4, y4.

The annotation convention follows the official DOTA description: https://captain-whu.github.io/DOTA/index.html
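As a worked example of this corner ordering, a minimal pure-Python sketch for typical boxes (`order_clockwise` is a hypothetical helper written for illustration; it is not part of labelme or mmrotate):

```python
import math

def order_clockwise(pts):
    """Order 4 corner points clockwise starting from the top-left corner,
    matching the DOTA convention x1,y1 x2,y2 x3,y3 x4,y4."""
    cx = sum(p[0] for p in pts) / 4.0
    cy = sum(p[1] for p in pts) / 4.0
    # image y-axis points down, so increasing atan2 angle is clockwise on screen
    pts = sorted(pts, key=lambda p: math.atan2(p[1] - cy, p[0] - cx))
    # rotate the list so it starts at the top-left corner (smallest x + y)
    start = min(range(4), key=lambda i: pts[i][0] + pts[i][1])
    return pts[start:] + pts[:start]

print(order_clockwise([(10, 0), (0, 0), (0, 5), (10, 5)]))
# [(0, 0), (10, 0), (10, 5), (0, 5)]  i.e. TL, TR, BR, BL
```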

Data augmentation helps here; see the CSDN post "python实现xml及json文件角度,亮度数据集增强" by MFT小白.

4. Training a model

4.1 Create the working folder

Training keeps using this folder throughout.

(mmrotate) ubuntu@ubuntu:~/mmrotate-0.3.2$ mkdir -p kfiouDataSets

The datasets subfolder holds the original jpg images and json files; keep a backup.

ubuntu@ubuntu:~/mmrotate-0.3.2/kfiouDataSets$ tree -L 2
.
└── datasets
    ├── 01.jpg
    ├── 01.json
    ├── ...
    ├── ...
    ├── 994.jpg
    └── 994.json
4.2 Train / val / test split

Create the split folders in advance and distribute the files into them manually by ratio, here roughly 8:2:2.

ubuntu@ubuntu:~/mmrotate-0.3.2/kfiouDataSets$ tree -L 1
.
├── datasets      # original dataset files
├── testDataset   # test set: jpg + json (may be empty)
├── trainDataset  # training set: jpg + json
└── valDataset    # validation set: jpg + json

4 directories, 1 file

4.3 DOTA annotation format

The labelme2dota.py script below generates one txt file per json file, with this structure:

x1, y1, x2, y2, x3, y3, x4, y4, category, difficult
x1, y1, x2, y2, x3, y3, x4, y4, category, difficult
# category: label name
# difficult: how hard the instance is to detect (1 = difficult, 0 = easy)

Reference: https://captain-whu.github.io/DOTA/index.html
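A minimal sketch of reading one such annotation line back (`parse_dota_line` is an illustrative helper, not a DOTA or mmrotate API):

```python
def parse_dota_line(line):
    """Split one DOTA annotation line into its parts:
    four polygon corner points, the category name, and the difficult flag."""
    parts = line.split()
    coords = [float(v) for v in parts[:8]]
    points = list(zip(coords[0::2], coords[1::2]))  # [(x1, y1), ..., (x4, y4)]
    category = parts[8]
    difficult = int(parts[9])
    return points, category, difficult

pts, cat, diff = parse_dota_line("10.0 0.0 60.0 0.0 60.0 30.0 10.0 30.0 barcode 0")
print(cat, diff, pts[0])  # barcode 0 (10.0, 0.0)
```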

Note: the network expects png input images, so the jpg files also need converting. I modified labelme2dota.py to handle this; code below.

import json
import os
from glob import glob
import argparse
import numpy as np
import shutil
from PIL import Image
import cv2
 
# convert labelme json annotations to DOTA txt format
# (and, for verification, convert DOTA txt back to labelme json)
def custombasename(fullname):
    return os.path.basename(os.path.splitext(fullname)[0])
 
 
def order_points_new(pts):  # order the points clockwise; see https://zhuanlan.zhihu.com/p/10643062
    # sort the points based on their x-coordinates
    xSorted = pts[np.argsort(pts[:, 0]), :]
 
    # grab the left-most and right-most points from the sorted
    # x-coordinate points
    leftMost = xSorted[:2, :]
    rightMost = xSorted[2:, :]
    if leftMost[0, 1] != leftMost[1, 1]:
        leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
    else:
        leftMost = leftMost[np.argsort(leftMost[:, 0])[::-1], :]
    (tl, bl) = leftMost
    if rightMost[0, 1] != rightMost[1, 1]:
        rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
    else:
        rightMost = rightMost[np.argsort(rightMost[:, 0])[::-1], :]
    (tr, br) = rightMost
    # return the coordinates in top-left, top-right,
    # bottom-right, and bottom-left order
    return np.array([tl, tr, br, bl], dtype="float32")
 
 
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument('--input_dir', default=r'kfiouDataSets/trainDataset', type=str,
                    help='input annotated directory')
parser.add_argument('--output_images', default=r'kfiouDataSets/datasets/images', type=str,
                    help='input annotated directory')
parser.add_argument('--output_dir', default=r'kfiouDataSets/train_annfile', type=str, help='output dataset directory')
parser.add_argument('--verify_dir', default=r'kfiouDataSets/datasets/verify', type=str,
                    help='verify dataset directory')
parser.add_argument('--verify', default=True, type=bool, help='verify')
parser.add_argument('--labels', default=r'kfiouDataSets/labels.txt', type=str, help='labels annotated directory')
args = parser.parse_args()
 
if not os.path.exists(args.output_dir):
    os.makedirs(args.output_dir)
if not os.path.exists(args.output_images):
    os.makedirs(args.output_images)
print('Creating dataset:', args.output_dir)
 
file_list = glob(os.path.join(args.input_dir, ".".join(["*", "json"])))
 
for i in range(len(file_list)):
    with open(file_list[i]) as f:
        label_str = f.read()
        label_dict = json.loads(label_str)  # parse the labelme json into a dict
 
        # path of the output txt file
        out_file = os.path.join(args.output_dir, ".".join([custombasename(file_list[i]), 'txt']))
        # write the four polygon corner points and the label
        fout = open(out_file, 'w')
        out_str = ''
        for shape_dict in label_dict['shapes']:
            points = shape_dict['points']
            item_points = []
            for p in points:
                item_points.append([p[0], p[1]])
            item_points = order_points_new(np.array(item_points, dtype="float"))
            for p in item_points.tolist():
                out_str += (str(p[0]) + ' ' + str(p[1]) + ' ')
            out_str += shape_dict['label'] + ' 0\n'
        fout.write(out_str)
        fout.close()
    print('%d/%d' % (i + 1, len(file_list)))
    print("labelme2dota...")
if args.verify:
    if not os.path.exists(args.verify_dir):
        os.makedirs(args.verify_dir)
    txt_list = glob(os.path.join(args.output_dir, ".".join(["*", "txt"])))
    for i in range(len(txt_list)):
        (filepath, tempfilename) = os.path.split(txt_list[i])
        (filename, extension) = os.path.splitext(tempfilename)
        sourcePath = None
        image_filename = None
        if os.path.exists(os.path.join(args.input_dir, ".".join([filename, "jpg"]))):
            sourcePath = os.path.join(args.input_dir, ".".join([filename, "jpg"]))            
            image_filename = ".".join([filename, "png"])
            print(filename)
            print(image_filename)
        elif os.path.exists(os.path.join(args.input_dir, ".".join([filename, "png"]))):
            sourcePath = os.path.join(args.input_dir, ".".join([filename, "png"]))
            image_filename = ".".join([filename, "png"])
        if sourcePath is None:
            print("check photo type")
            continue
        targetPath = os.path.join(args.verify_dir,image_filename)
        targetpng = os.path.join(args.output_images, image_filename)
        shutil.copy(sourcePath, targetPath)
        shutil.copy(sourcePath, targetpng)
        img = Image.open(sourcePath)
        imgSize = img.size  # (width, height)
        w = img.width   # image width
        h = img.height  # image height
 
        data = {}
        data['imagePath'] = image_filename
        data['flags'] = {}
        data['imageWidth'] = w
        data['imageHeight'] = h
        data['imageData'] = None
        data['version'] = "5.0.1"
        data["shapes"] = []
 
        with open(txt_list[i]) as f:
            label_str = f.readlines()
            for label_item in label_str:
                line_char = label_item.split("\n")[0].split(' ')
                # use float() instead of eval() to parse coordinates safely
                points = [[float(line_char[0]), float(line_char[1])], [float(line_char[2]), float(line_char[3])],
                          [float(line_char[4]), float(line_char[5])], [float(line_char[6]), float(line_char[7])]]
                itemData = {'points': []}
                itemData['points'].extend(points)
                itemData["flags"] = {}  # labelme expects "flags", not "flag"
                itemData["group_id"] = None
                itemData["shape_type"] = "polygon"
                itemData["label"] = line_char[-2]
                data["shapes"].append(itemData)
 
            jsonName = ".".join([filename, "json"])
            jsonPath = os.path.join(args.verify_dir, jsonName)
            with open(jsonPath, "w") as f:
                json.dump(data, f)
            print(jsonName)
            print("dota2labelme...")
--output_images # where all images converted to png are written
--output_dir    # where the txt annotation files for this split are written
--verify_dir    # where the verification copies (png + json) are written
--verify        # default True; verification copies are only written when enabled
--labels        # path to a pre-created labels.txt listing the label names



ubuntu@ubuntu:~/mmrotate-0.3.2/kfiouDataSets$ cat labels.txt 
barcode
The folder structure at this point, roughly:

ubuntu@ubuntu:~/mmrotate-0.3.2/kfiouDataSets$ tree -L 1
.
├── datasets        # original dataset files: jpg + json
├── labels.txt      # label file
├── source          # png dataset
├── test_annfile    # test-set txt annotations
├── testDataset     # test set
├── train_annfile   # training-set txt annotations
├── trainDataset    # training set
├── val_annfile     # validation-set txt annotations
└── valDataset      # validation set

8 directories, 1 file
4.4 Data conversion
ubuntu@ubuntu:~/mmrotate-0.3.2/kfiouDataSets$ python3 labelme2dota.py --input_dir kfiouDataSets/trainDataset/ --output_dir kfiouDataSets/train_annfile --verify True --labels kfiouDataSets/labels.txt
ubuntu@ubuntu:~/mmrotate-0.3.2/kfiouDataSets$ python3 labelme2dota.py --input_dir kfiouDataSets/valDataset/ --output_dir kfiouDataSets/val_annfile --verify false --labels kfiouDataSets/labels.txt
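One caveat with `--verify false` above: the script declares the argument with `type=bool`, and in Python `bool('false')` is `True` because any non-empty string is truthy, so passing `false` on the command line does not actually disable verification. A minimal demonstration:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--verify', default=True, type=bool)

# any non-empty string is truthy, so "false" still enables verification
args = parser.parse_args(['--verify', 'false'])
print(args.verify)  # True

# only an empty string produces False with type=bool
args = parser.parse_args(['--verify', ''])
print(args.verify)  # False
```

To make string values like `false` work as expected, the usual fix is a small converter function (e.g. mapping `{'true','1'}` to `True`) passed as `type=` instead of `bool`.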
4.5 Edit the dataset config
/home/ubuntu/mmrotate-0.3.2/configs/_base_/datasets/dotav1.py

Change the dataset root and the train / val / test annotation and image directories:

# dataset settings
dataset_type = 'DOTADataset'
data_root = 'kfiouDataSets/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='RResize', img_scale=(1024, 1024)),
    dict(type='RRandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1024, 1024),
        flip=False,
        transforms=[
            dict(type='RResize'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='DefaultFormatBundle'),
            dict(type='Collect', keys=['img'])
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'train_annfile/',
        img_prefix=data_root + 'source/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'val_annfile/',
        img_prefix=data_root + 'source/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'test_annfile/',
        img_prefix=data_root + 'source/',
        pipeline=test_pipeline))
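For intuition, the `Normalize` step applies per-channel `(value - mean) / std` using the `img_norm_cfg` values above. A minimal sketch of that arithmetic (not mmcv's actual implementation, which operates on whole image arrays):

```python
# the ImageNet statistics from img_norm_cfg above
mean = [123.675, 116.28, 103.53]
std = [58.395, 57.12, 57.375]

def normalize_pixel(rgb):
    """Per-channel (value - mean) / std, as applied to every pixel."""
    return [(v - m) / s for v, m, s in zip(rgb, mean, std)]

print(normalize_pixel([123.675, 116.28, 103.53]))  # ~[0.0, 0.0, 0.0]
```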
4.6 Edit the class labels
ubuntu@ubuntu:~/mmrotate-0.3.2$ sudo gedit mmrotate/datasets/dota.py
class DOTADataset(CustomDataset):
    """DOTA dataset for detection.

    Args:
        ann_file (str): Annotation file path.
        pipeline (list[dict]): Processing pipeline.
        version (str, optional): Angle representations. Defaults to 'oc'.
        difficulty (bool, optional): The difficulty threshold of GT.
    """

    # CLASSES = ('plane', 'baseball-diamond', 'bridge', 'ground-track-field',
    #            'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
    #            'basketball-court', 'storage-tank', 'soccer-ball-field',
    #            'roundabout', 'harbor', 'swimming-pool', 'helicopter')
    
    CLASSES = ('barcode',)  # the trailing comma is required to keep this a tuple

Note: the copy of this file installed inside the conda environment must be edited as well. I located it with FSearch at:

/home/ubuntu/miniconda3/envs/mmrotate/lib/python3.8/site-packages/mmrotate/datasets
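The trailing comma in `CLASSES = ('barcode',)` is what makes it a tuple; without it Python treats the parentheses as grouping and `CLASSES` becomes a plain string, so iterating over it yields single characters instead of class names:

```python
CLASSES = ('barcode')    # no comma: just a string in parentheses
print(type(CLASSES).__name__, len(CLASSES))  # str 7

CLASSES = ('barcode',)   # trailing comma: a one-element tuple
print(type(CLASSES).__name__, len(CLASSES))  # tuple 1
```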

4.7 Set the number of classes
ubuntu@ubuntu:~/mmrotate-0.3.2$ sudo gedit configs/r3det/r3det_r50_fpn_1x_dota_oc.py
ubuntu@ubuntu:~/mmrotate-0.3.2$ sudo gedit configs/kfiou/r3det_kfiou_ln_r50_fpn_1x_dota_oc.py

num_classes=1  # one class: barcode
4.8 Create a folder for training logs and weights
ubuntu@ubuntu:~/mmrotate-0.3.2$ mkdir -p run
4.9 Start training
(mmrotate) ubuntu@ubuntu:~/mmrotate-0.3.2$ python3 tools/train.py configs/kfiou/r3det_kfiou_ln_r50_fpn_1x_dota_oc.py --work-dir=run
2023-08-07 15:18:24,557 - mmrotate - INFO - workflow: [('train', 1)], max: 50 epochs
2023-08-07 15:18:24,557 - mmrotate - INFO - Checkpoints will be saved to /home/ubuntu/mmrotate-0.3.2/run by HardDiskBackend.
/home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  /pytorch/c10/core/TensorImpl.h:1156.)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
2023-08-07 15:18:57,113 - mmrotate - INFO - Epoch [1][50/364]	lr: 9.967e-04, eta: 3:16:54, time: 0.651, data_time: 0.047, memory: 2801, s0.loss_cls: 1.1590, s0.loss_bbox: 6.6705, sr0.loss_cls: 1.1172, sr0.loss_bbox: 6.5282, loss: 15.4748, grad_norm: 16.5861
2023-08-07 15:19:27,457 - mmrotate - INFO - Epoch [1][100/364]	lr: 1.163e-03, eta: 3:09:43, time: 0.607, data_time: 0.003, memory: 2801, s0.loss_cls: 0.9835, s0.loss_bbox: 6.0528, sr0.loss_cls: 0.5768, sr0.loss_bbox: 6.0088, loss: 13.6219, grad_norm: 28.2796
2023-08-07 15:19:57,906 - mmrotate - INFO - Epoch [1][150/364]	lr: 1.330e-03, eta: 3:07:11, time: 0.609, data_time: 0.003, memory: 2801, s0.loss_cls: 0.4865, s0.loss_bbox: 5.8674, sr0.loss_cls: 0.1423, sr0.loss_bbox: 5.9103, loss: 12.4065, grad_norm: 24.5735
2023-08-07 15:20:28,581 - mmrotate - INFO - Epoch [1][200/364]	lr: 1.497e-03, eta: 3:06:01, time: 0.613, data_time: 0.004, memory: 2801, s0.loss_cls: 0.3284, s0.loss_bbox: 5.7260, sr0.loss_cls: 0.0881, sr0.loss_bbox: 5.8602, loss: 12.0028, grad_norm: 17.1326
2023-08-07 15:20:59,550 - mmrotate - INFO - Epoch [1][250/364]	lr: 1.663e-03, eta: 3:05:27, time: 0.619, data_time: 0.004, memory: 2801, s0.loss_cls: 0.3540, s0.loss_bbox: 5.7076, sr0.loss_cls: 0.0388, sr0.loss_bbox: 5.8367, loss: 11.9372, grad_norm: 13.2928
2023-08-07 15:21:30,239 - mmrotate - INFO - Epoch [1][300/364]	lr: 1.830e-03, eta: 3:04:38, time: 0.614, data_time: 0.004, memory: 2801, s0.loss_cls: 0.6266, s0.loss_bbox: 5.6433, sr0.loss_cls: 0.0734, sr0.loss_bbox: 5.7544, loss: 12.0977, grad_norm: 13.4350
2023-08-07 15:22:00,668 - mmrotate - INFO - Epoch [1][350/364]	lr: 1.997e-03, eta: 3:03:41, time: 0.609, data_time: 0.004, memory: 2801, s0.loss_cls: 0.3442, s0.loss_bbox: 5.6681, sr0.loss_cls: 0.0265, sr0.loss_bbox: 5.8284, loss: 11.8672, grad_norm: 10.0150
2023-08-07 15:22:09,234 - mmrotate - INFO - Saving checkpoint at 1 epochs
[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 182/182, 7.7 task/s, elapsed: 23s, ETA:     0s2023-08-07 15:22:35,882 - mmrotate - INFO - 
+---------+-----+------+--------+-------+
| class   | gts | dets | recall | ap    |
+---------+-----+------+--------+-------+
| barcode | 182 | 189  | 1.000  | 1.000 |
+---------+-----+------+--------+-------+
| mAP     |     |      |        | 1.000 |
+---------+-----+------+--------+-------+

Training is up and running.

4.10 Test the model
(mmrotate) ubuntu@ubuntu:~/mmrotate-0.3.2$ python demo/image_demo.py demo/80.png configs/kfiou/r3det_kfiou_ln_r50_fpn_1x_dota_oc.py run/epoch_50.pth --out-file demo/test.jpg

5. Model conversion

Deployment steps to be added later...

References:

1. "mmrotate旋转目标检测框架的学习与使用", LinuxMelo's CSDN blog

2. "基于mmrotate的旋转目标检测入门详解", Orange-_-'s CSDN blog

3. Prerequisites — mmrotate documentation
