Windows10使用MMrotate(初学),并训练自己的数据集

安装环境

        Windows10安装mmrotate的步骤其实和linux一样的

        首先检查自己所用的显卡以及对应的CUDA版本

        我这里使用的是NVIDIA 3080TI 安装的CUDA版本是11.0

        整个的安装过程其实mmrotate官网已经给出了,但是安装中难免会有些小问题,所以这里也记录一下。

        创建一个虚拟环境并激活。该环境就是以后用来运行mmrotate的了,这里安装的是python3.8。

conda create -n mmrotate python=3.8

conda activate mmrotate

        因为我没有在创建环境的时候安装pytorch和cudatoolkit,所以要单独安装一下。这里我的CUDA版本是11.0,安装cudatoolkit11.1没有出现问题。之前在一台Ubantu的机器上安装了CUDA11.4,使用cudatoolkit11.3也没有问题。

conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge

        之后也是按照官网的安装步骤,但是有一点mmcv-full请安装1.6.0,因为我使用默认的命令会安装1.6.1运行时会报错。

pip install openmim

mim install mmcv-full==1.6.0

mim install mmdet

        然后就是下载mmrotate的源代码,使用git命令跟网站下载压缩包都一样,反正后边命令切换到mmrotate源码的路径下就行了

git clone https://github.com/open-mmlab/mmrotate.git

cd mmrotate

pip install -r requirements/build.txt

pip install -v -e .

        到这里应该都不会出现任何问题了,mmroate的运行环境也已经配置好。


数据集配置

        这一点其实很多文章也都有了介绍,大多数也是使用了dota数据集的格式。大部分人使用rolabelimg标注的数据格式转成dota的txt,其实本质就是VOC格式转DOTA格式,这里给一个HRSC转DOTA格式的。HRSC与rolableimg的标注格式其实都是VOC,只不过细节略有不同。

import os
import sys
import json
import os.path as osp
import numpy as np
import xmltodict
from tqdm import tqdm

sys.path.append("..")
from dota_poly2rbox import rbox2poly_single


def parse_ann_info(objects):
    bboxes, labels, bboxes_ignore, labels_ignore = [], [], [], []
    # only one annotation
    if type(objects) != list:
        objects = [objects]
    for obj in objects:
        if obj['difficult'] == '0':
            bbox = float(obj['mbox_cx']), float(obj['mbox_cy']), float(
                obj['mbox_w']), float(obj['mbox_h']), float(obj['mbox_ang'])
            label = 'ship'
            bboxes.append(bbox)
            labels.append(label)
        elif obj['difficult'] == '1':
            bbox = float(obj['mbox_cx']), float(obj['mbox_cy']), float(
                obj['mbox_w']), float(obj['mbox_h']), float(obj['mbox_ang'])
            label = 'ship'
            bboxes_ignore.append(bbox)
            labels_ignore.append(label)
    return bboxes, labels, bboxes_ignore, labels_ignore


def ann_to_txt(ann):
    out_str = ''
    for bbox, label in zip(ann['bboxes'], ann['labels']):
        poly = rbox2poly_single(bbox)
        str_line = '{} {} {} {} {} {} {} {} {} {}\n'.format(
            poly[0], poly[1], poly[2], poly[3], poly[4], poly[5], poly[6], poly[7], label, '0')
        out_str += str_line
    for bbox, label in zip(ann['bboxes_ignore'], ann['labels_ignore']):
        poly = rbox2poly_single(bbox)
        str_line = '{} {} {} {} {} {} {} {} {} {}\n'.format(
            poly[0], poly[1], poly[2], poly[3], poly[4], poly[5], poly[6], poly[7], label, '1')
        out_str += str_line
    return out_str


def generate_txt_labels(root_path):
    img_path = osp.join(root_path, 'images')
    label_path = osp.join(root_path, 'annotations')
    label_txt_path = osp.join(root_path, 'labelTxt')
    if not osp.exists(label_txt_path):
        os.mkdir(label_txt_path)

    img_names = [osp.splitext(img_name.strip())[0] for img_name in os.listdir(img_path)]
    pbar = tqdm(img_names)
    for img_name in pbar:
        pbar.set_description("HRSC2016 Preparation...")

        label = osp.join(label_path, img_name + '.xml')
        label_txt = osp.join(label_txt_path, img_name + '.txt')
        f_label = open(label)
        data_dict = xmltodict.parse(f_label.read())
        data_dict = data_dict['HRSC_Image']
        f_label.close()
        label_txt_str = ''
        # with annotations
        if data_dict['HRSC_Objects']:
            objects = data_dict['HRSC_Objects']['HRSC_Object']
            bboxes, labels, bboxes_ignore, labels_ignore = parse_ann_info(
                objects)
            ann = dict(
                bboxes=bboxes,
                labels=labels,
                bboxes_ignore=bboxes_ignore,
                labels_ignore=labels_ignore)
            label_txt_str = ann_to_txt(ann)
        with open(label_txt, 'w') as f_txt:
            f_txt.write(label_txt_str)


if __name__ == '__main__':
    generate_txt_labels('/project/jmhan/data/HRSC2016/Train')
    generate_txt_labels('/project/jmhan/data/HRSC2016/Test')
    print('done!')

训练设置

        接下来到训练部分,总结了一下需要进行的操作:

  1. mmrotate需要输入的图像大小为方片,所以建议使用tools/data/dota/split/img_splits切割图像。修改--base-json参数所对应的路径。因为我习惯使用pycharm运行,不习惯使用命令行,所以我直接修改了default。运行该文件需要安装shapely模块。
    def add_parser(parser):
        """Add arguments."""
        parser.add_argument(
            '--base-json',
            type=str,
            default='./split_configs/ss_val.json',
            help='json config file for split images')
  2. 当然这之前肯定还要设置tools/data/dota/split/split_configs/ss_train.py 这些mmrotate都给出了模板,ms_train会得到多尺度的裁剪结果,ss_train是单尺度的。至少要裁剪train、val来完成最基本的训练和评估,该文件主要修改的是输入输出的路径信息。
    {
      "nproc": 10,
      "img_dirs": [
        "./HRSC2DOTA/train/images/"
      ],
      "ann_dirs": [
        "./HRSC2DOTA/train/labelTxt/"
      ],
      "sizes": [
        1024
      ],
      "gaps": [
        200
      ],
      "rates": [
        1.0
      ],
      "img_rate_thr": 0.6,
      "iof_thr": 0.7,
      "no_padding": false,
      "padding_value": [
        104,
        116,
        124
      ],
      "save_dir": "../split_ss_hrsc2dota/train/",
      "save_ext": ".png"
  3. 裁剪完成后,自然要进行训练参数等内容修改。训练文件存放在tools/train.py,需要修改config文件的路径,和存放训练权重等信息的路径(当然也可以没有)。还是因为上述的原因,我将config变成了可选参数以便于能够直接从pycharm运行。。
    def parse_args():
        parser = argparse.ArgumentParser(description='Train a detector')
        parser.add_argument('--config',default='../configs/rotated_faster_rcnn/rotated_faster_rcnn_r50_fpn_1x_dota_le90.py', help='train config file path')
        parser.add_argument('--work-dir',default='model_weight', help='the dir to save logs and models')
  4. 上边的config选择了rotated_faster_rcnn作为旋转目标检测模型,该文件中仅有一处num_class需要修改,但像R3Det中会有两处,这点需要注意。该文件中还引用了三个文件,因此有部分参数还需要进入这些文件设置。
    _base_ = [
        '../_base_/datasets/dotav1.py', '../_base_/schedules/schedule_1x.py',
        '../_base_/default_runtime.py'
    ]
  5. 首先是dotaV1,这里主要注意要指明数据集的存放路径(也就是你切割后的路径),还有train和val分别对应的路径。官方给的路径中val和train默认的是相同的路径,需要注意啊。。。
    dataset_type = 'DOTADataset'
    data_root = 'data/dota/split_ss_hrsc2dota/'
    data = dict(
        samples_per_gpu=2,
        workers_per_gpu=2,
        train=dict(
            type=dataset_type,
            ann_file=data_root + 'train/annfiles/',
            img_prefix=data_root + 'train/images/',
            pipeline=train_pipeline),
        val=dict(
            type=dataset_type,
            ann_file=data_root + 'val/annfiles/',
            img_prefix=data_root + 'val/images/',
            pipeline=test_pipeline),
        test=dict(
            type=dataset_type,
            ann_file=data_root + 'val/annfiles/',
            img_prefix=data_root + 'val/images/',
            pipeline=test_pipeline))
  6. schedule_1x则是修改训练所使用的优化器等信息,这点根据个人来修改吧,没啥好说的。
  7. 还有一点,当使用自己的数据集时,要重新定义dataset类的信息,因为dataload的时候是通过这个类获取到的你的数据。该文件存放在mmrotate/datasets/dota.py,这里的mmrotate是mmrotate源码下的一个同名文件,例如我要设置仅有一类船只进行训练。
        # CLASSES = ('plane', 'baseball-diamond', 'bridge', 'ground-track-field',
        #            'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
        #            'basketball-court', 'storage-tank', 'soccer-ball-field',
        #            'roundabout', 'harbor', 'swimming-pool', 'helicopter')
        #
        # PALETTE = [(165, 42, 42), (189, 183, 107), (0, 255, 0), (255, 0, 0),
        #            (138, 43, 226), (255, 128, 0), (255, 0, 255), (0, 255, 255),
        #            (255, 193, 193), (0, 51, 153), (255, 250, 205), (0, 139, 139),
        #            (255, 255, 0), (147, 116, 116), (0, 0, 255)]
        CLASSES = ('ship',)#只有一类时一定要添加逗号,之前我没有添加然后发现读取的类别是's','h','i','p'。。。。
    
        PALETTE = [(165, 42, 42),]
  8. 然后就是batch_size在mmrotate中没有使用该说法,但是可以在mmrotate/apis/train.py中找到,这里的mmrotate是mmrotate源码下的一个同名文件,实际的batch_size就是你使用的GPU数量*samples_per_gpu。
        train_dataloader_default_args = dict(
            samples_per_gpu=4,
            workers_per_gpu=2,
            # `num_gpus` will be ignored if distributed
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed,
            runner_type=runner_type,
            persistent_workers=False)
    

        总体来说mmrotate同样支持windows而且配置还挺简单,有点超乎我的预料。作为一个新兴的框架,无疑完成一些工程任务会方便很多。但是里边的一些参数文件需要互相调用,阅读起来有些复杂(对我来说),感觉小坑就是里边的各种路径设置,基本都是基于你的运行文件,比如运行train文件train所依赖的其他参数文件中的路径设置都是基于train的而并非参数文件本身。总之后边要熟练使用这个框架还有很长的一段路要走,这个博客应该也会继续更新。


参考

基于MMRotate训练自定义数据集 做旋转目标检测 2022-3-30_YD-阿三的博客-CSDN博客_旋转目标检测数据集

使用mmdetection训练和评估自定义数据集 - 知乎

MMRotate从零开始训练自己的数据集_江小白jlj的博客-CSDN博客

  • 3
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值