[Object Detection] [maskrcnn_benchmark] Training on a custom dataset

This post records how to train a model on a custom dataset with the maskrcnn_benchmark framework.

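1. My dataset

I recommend converting your dataset to COCO format; the COCO API can be used for this.
Below I use the Airbus_Ship_Detection dataset as an example of how to add a dataset to maskrcnn_benchmark.
I have already converted it to COCO style. Its directory structure is as follows: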

Ship_Detection_Airbus
|_ train
|  |_ images
|  |  |_ <im-1-name>.jpg
|  |  |_ ...
|  |  |_ <im-N-name>.jpg
|  |
|  |_ instance_train.json
|  
|_ val
   |_ images
   |  |_ <im-1-name>.jpg
   |  |_ ...
   |  |_ <im-N-name>.jpg
   |
   |_ instance_val.json
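
For reference, a COCO-style annotation file has three top-level lists: images, annotations, and categories. The sketch below builds a tiny hypothetical example (the file name, image size, box, and the single "ship" category are assumptions for illustration, not values from the Airbus data) and checks the cross-references between the three lists.

```python
import json


def check_coco_dict(coco):
    """Return a list of problems found in a COCO-style dict (empty if none)."""
    problems = []
    for key in ("images", "annotations", "categories"):
        if key not in coco:
            problems.append("missing top-level key: " + key)
    if problems:
        return problems

    image_ids = {img["id"] for img in coco["images"]}
    category_ids = {cat["id"] for cat in coco["categories"]}
    for ann in coco["annotations"]:
        if ann["image_id"] not in image_ids:
            problems.append("annotation %s: unknown image_id" % ann["id"])
        if ann["category_id"] not in category_ids:
            problems.append("annotation %s: unknown category_id" % ann["id"])
        _, _, w, h = ann["bbox"]  # COCO boxes are [x, y, width, height]
        if w <= 0 or h <= 0:
            problems.append("annotation %s: empty bbox" % ann["id"])
    return problems


# Hypothetical one-image, one-box example.
example = {
    "images": [{"id": 1, "file_name": "example.jpg", "width": 768, "height": 768}],
    "categories": [{"id": 1, "name": "ship"}],
    "annotations": [
        {"id": 10, "image_id": 1, "category_id": 1,
         "bbox": [100, 120, 40, 25], "area": 1000, "iscrowd": 0},
    ],
}

# Round-trip through JSON to mimic loading the file from disk.
example = json.loads(json.dumps(example))
print(check_coco_dict(example))
```

To check a real file, load it with json.load and pass the resulting dict to check_coco_dict.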

2. Add a symbolic link

(maskrcnn-benchmark)$ cd ~/SrcLibs/maskrcnn-benchmark/datasets
(maskrcnn-benchmark)$ ln -s /path/to/your/Ship_Detection_Airbus ship_det_airbus
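
Before going further it can help to confirm that the link resolves to the layout from section 1. The following is a minimal sketch, assuming the relative paths shown in that directory tree; the function name missing_entries is my own, and the demo runs on a throwaway directory rather than the real data.

```python
import tempfile
from pathlib import Path

# Relative paths taken from the directory tree in section 1.
EXPECTED = (
    "train/images",
    "train/instance_train.json",
    "val/images",
    "val/instance_val.json",
)


def missing_entries(dataset_root):
    """List the expected entries that do not exist under dataset_root.

    Path.exists() follows symbolic links, so a dangling link shows up
    as everything being missing.
    """
    root = Path(dataset_root)
    return [rel for rel in EXPECTED if not (root / rel).exists()]


# Demo on a temporary directory that only contains the train split.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "train" / "images").mkdir(parents=True)
    (root / "train" / "instance_train.json").write_text("{}")
    demo_missing = missing_entries(root)

print(demo_missing)
```

For the real data, pass the path of the link created above instead of the temporary directory.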

3. Add the dataset paths to the project

In ${ROOT_DIR}/maskrcnn_benchmark/config/paths_catalog.py, add the following to DatasetCatalog:

import os
from pathlib import Path


class DatasetCatalog(object):
    THIS_DIR = Path(__file__).parent  # ${ROOT_DIR}/maskrcnn_benchmark/config
    DATA_DIR = THIS_DIR.parents[1].joinpath("datasets")  # ${ROOT_DIR}/datasets
    DATASETS = {
        "coco_2017_train": {
            "img_dir": "coco/train2017",
            "ann_file": "coco/annotations/instances_train2017.json"
        },
        "coco_2017_val": {
            "img_dir": "coco/val2017",
            "ann_file": "coco/annotations/instances_val2017.json"
        },
        ...
        ...
        # ship_detection_airbus (file names follow the directory tree in section 1)
        "ship_det_airbus_train": {                                       # add this line
            "img_dir": "ship_det_airbus/train/images",                   # add this line
            "ann_file": "ship_det_airbus/train/instance_train.json"      # add this line
        },                                                               # add this line
        "ship_det_airbus_val": {                                         # add this line
            "img_dir": "ship_det_airbus/val/images",                     # add this line
            "ann_file": "ship_det_airbus/val/instance_val.json"          # add this line
        },                                                               # add this line
    }

    @staticmethod
    def get(name):
        if "coco" in name:
            ...
            ...
        elif "voc" in name:
            ...
            ...
        elif "cityscapes" in name:
            ...
            ...
        # ship_detection_airbus
        elif "ship_det_airbus" in name:                              # add this line
            data_dir = DatasetCatalog.DATA_DIR                       # add this line
            attrs = DatasetCatalog.DATASETS[name]                    # add this line
            args = dict(                                             # add this line
                root=os.path.join(data_dir, attrs["img_dir"]),       # add this line
                ann_file=os.path.join(data_dir, attrs["ann_file"]),  # add this line
            )                                                        # add this line
            return dict(                                             # add this line
                factory="ShipDetAirbusDataset",                      # add this line
                args=args,                                           # add this line
            )                                                        # add this line
        raise RuntimeError("Dataset not available: {}".format(name))
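
To see what the new branch returns without importing the framework, here is a standalone sketch that mirrors only the lookup logic above. The module-level DATA_DIR and the stripped-down get function are simplifications of mine, and the annotation file names follow the directory tree in section 1.

```python
import os

DATA_DIR = "datasets"  # assumption: relative to the repository root

DATASETS = {
    "ship_det_airbus_train": {
        "img_dir": "ship_det_airbus/train/images",
        "ann_file": "ship_det_airbus/train/instance_train.json",
    },
    "ship_det_airbus_val": {
        "img_dir": "ship_det_airbus/val/images",
        "ann_file": "ship_det_airbus/val/instance_val.json",
    },
}


def get(name):
    """Resolve a dataset name to a factory name plus constructor arguments."""
    if "ship_det_airbus" in name:
        attrs = DATASETS[name]
        args = dict(
            root=os.path.join(DATA_DIR, attrs["img_dir"]),
            ann_file=os.path.join(DATA_DIR, attrs["ann_file"]),
        )
        return dict(factory="ShipDetAirbusDataset", args=args)
    raise RuntimeError("Dataset not available: {}".format(name))


entry = get("ship_det_airbus_val")
print(entry["factory"])
print(entry["args"]["root"])
print(entry["args"]["ann_file"])
```

The factory string is the class name that the data-loading code looks up later, which is why the Dataset class in the next section has to be exported under exactly that name.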

4. Add the Dataset class wrapper

Add ship_det_airbus.py under ${ROOT_DIR}/maskrcnn_benchmark/data/datasets.
A shortcut is to copy coco.py and rename the copy:

(maskrcnn-benchmark)$ cd ~/SrcLibs/maskrcnn-benchmark/maskrcnn_benchmark/data/datasets/
(maskrcnn-benchmark)$ cp coco.py ship_det_airbus.py

In ship_det_airbus.py, replace COCODataset with ShipDetAirbusDataset and comment out the following line:
self.categories = {cat['id']: cat['name'] for cat in self.coco.cats.values()}
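
The rest of the copied class stays as it is. As I understand coco.py, that includes a remapping of the category ids in the JSON file to contiguous labels starting at 1, with 0 left for the background. A standalone sketch of that idea, with hypothetical category ids and a helper name of my own:

```python
def contiguous_label_map(category_ids):
    """Map arbitrary JSON category ids to labels 1..N (0 is left for background)."""
    return {cat_id: i + 1 for i, cat_id in enumerate(sorted(category_ids))}


# Hypothetical annotation file with non-contiguous category ids.
label_map = contiguous_label_map([7, 3, 42])
print(label_map)

# A single-category dataset such as ship detection gets one foreground label.
ship_map = contiguous_label_map([1])
num_classes_with_background = len(ship_map) + 1
print(num_classes_with_background)
```

The count including background is, to my knowledge, what MODEL.ROI_BOX_HEAD.NUM_CLASSES in the config refers to; treat that as something to verify against your version of the framework.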

Add the following statements to ~/SrcLibs/maskrcnn-benchmark/maskrcnn_benchmark/data/datasets/__init__.py:

...
...
from .ship_det_airbus import ShipDetAirbusDataset    # add this line

__all__ = [
    "COCODataset",
    ...
    ...
    "ShipDetAirbusDataset",    # add this line
]

4.1 Test the data pipeline

Run the following script:

import matplotlib.pyplot as plt

from maskrcnn_benchmark.config.paths_catalog import DatasetCatalog
from maskrcnn_benchmark.data.build import build_dataset


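# build_dataset returns a list of datasets; take the first (and only) one.
ds_airbus = build_dataset(['ship_det_airbus_train'], transforms=None, dataset_catalog=DatasetCatalog)[0]

# With transforms=None, the first element of each item is the untransformed image.
plt.imshow(ds_airbus[0][0])
plt.show()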


5. Add the evaluation code

(maskrcnn-benchmark)$ cd ${ROOT_DIR}/maskrcnn_benchmark/data/datasets/evaluation
(maskrcnn-benchmark)$ cp -R coco ship_det_airbus
(maskrcnn-benchmark)$ cd ship_det_airbus/
(maskrcnn-benchmark)$ mv coco_eval.py ship_det_airbus_eval.py
(maskrcnn-benchmark)$ rm -rf coco_wrapper.py abs_to_coco.py

In ${ROOT_DIR}/maskrcnn_benchmark/data/datasets/evaluation/ship_det_airbus/ship_det_airbus_eval.py, rename

def do_coco_evaluation(
    ...
):

to

def do_ship_det_airbus_evaluation(
    ...
):

Then edit ${ROOT_DIR}/maskrcnn_benchmark/data/datasets/evaluation/ship_det_airbus/__init__.py so that it reads:

from .ship_det_airbus_eval import do_ship_det_airbus_evaluation as do_orig_coco_evaluation
from maskrcnn_benchmark.data.datasets import ShipDetAirbusDataset


def ship_det_airbus_evaluation(
    dataset,
    predictions,
    output_folder,
    box_only,
    iou_types,
    expected_results,
    expected_results_sigma_tol,
):
    if isinstance(dataset, ShipDetAirbusDataset):
        return do_orig_coco_evaluation(
            dataset=dataset,
            predictions=predictions,
            box_only=box_only,
            output_folder=output_folder,
            iou_types=iou_types,
            expected_results=expected_results,
            expected_results_sigma_tol=expected_results_sigma_tol,
        )
    else:
        raise NotImplementedError(
            (
                "type(dataset)=%s" % type(dataset)
            )
        )

Add the following code to ${ROOT_DIR}/maskrcnn_benchmark/data/datasets/evaluation/__init__.py:

...
...
from .ship_det_airbus import ship_det_airbus_evaluation


def evaluate(dataset, predictions, output_folder, **kwargs):
    ...
    ...
    ...
    elif isinstance(dataset, datasets.AbstractDataset):
        return abs_cityscapes_evaluation(**args)
    # ship_det_airbus
    elif isinstance(dataset, datasets.ShipDetAirbusDataset):    # add this line
        return ship_det_airbus_evaluation(**args)               # add this line
    else:
        dataset_name = dataset.__class__.__name__
        raise NotImplementedError("Unsupported dataset type {}.".format(dataset_name)) 
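
The evaluate function chooses an evaluator with an isinstance chain, so a branch is reached only if no earlier branch matches. The sketch below uses stand-in classes (none of them are the real maskrcnn_benchmark classes, and the chain is shortened) to show why a class copied from coco.py reaches the new branch, while a hypothetical subclass of COCODataset would be caught by the COCO branch first.

```python
class COCODataset:
    """Stand-in for the framework's COCO dataset class."""


class ShipDetAirbusDataset:
    """Stand-in for the class copied from coco.py; it does not inherit from COCODataset."""


class ShipAsCOCOSubclass(COCODataset):
    """Hypothetical alternative: subclassing COCODataset instead of copying it."""


def pick_evaluator(dataset):
    """Return the name of the evaluator that an isinstance chain would select."""
    if isinstance(dataset, COCODataset):
        return "coco_evaluation"
    elif isinstance(dataset, ShipDetAirbusDataset):
        return "ship_det_airbus_evaluation"
    raise NotImplementedError(
        "Unsupported dataset type {}.".format(dataset.__class__.__name__)
    )


print(pick_evaluator(ShipDetAirbusDataset()))
print(pick_evaluator(ShipAsCOCOSubclass()))
```

If you prefer subclassing over copying, the new branch would have to come before the COCO branch for the custom evaluator to be used.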

6. Run the code

scratch_e2e_faster_rcnn_R_50_FPN_3x_gn_ship_det_airbus
|_ scratch_e2e_faster_rcnn_R_50_FPN_3x_gn.yaml
|_ run.sh

Config file scratch_e2e_faster_rcnn_R_50_FPN_3x_gn.yaml:

INPUT:
  MIN_SIZE_TRAIN: (768,)
  MAX_SIZE_TRAIN: 768
  MIN_SIZE_TEST: 768
  MAX_SIZE_TEST: 768
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  WEIGHT: "" # no pretrained model
  BACKBONE:
    CONV_BODY: "R-50-FPN"
    FREEZE_CONV_BODY_AT: 0 # finetune all layers
  RESNETS: # use GN for backbone
    BACKBONE_OUT_CHANNELS: 256
    STRIDE_IN_1X1: False
    TRANS_FUNC: "BottleneckWithGN"
    STEM_FUNC: "StemWithGN"
  FPN:
    USE_GN: True # use GN for FPN
  RPN:
    USE_FPN: True
    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
    PRE_NMS_TOP_N_TRAIN: 2000
    PRE_NMS_TOP_N_TEST: 1000
    POST_NMS_TOP_N_TEST: 1000
    FPN_POST_NMS_TOP_N_TEST: 1000
  ROI_HEADS:
    USE_FPN: True
    BATCH_SIZE_PER_IMAGE: 512
    POSITIVE_FRACTION: 0.25
  ROI_BOX_HEAD:
    USE_GN: True # use GN for bbox head
    POOLER_RESOLUTION: 7
    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
    POOLER_SAMPLING_RATIO: 2
    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
    PREDICTOR: "FPNPredictor"
DATASETS:
  TRAIN: ("ship_det_airbus_train",)
  TEST: ("ship_det_airbus_val",)
DATALOADER:
  SIZE_DIVISIBILITY: 32
SOLVER:
  # Assume 8 gpus
  BASE_LR: 0.02
  WEIGHT_DECAY: 0.0001
  STEPS: (210000, 250000)
  MAX_ITER: 270000
  IMS_PER_BATCH: 16
TEST:
  IMS_PER_BATCH: 8
OUTPUT_DIR: output
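
A note on the DATASETS entries: they are written as Python tuple literals, and a one-element tuple needs a trailing comma. maskrcnn_benchmark reads its config through yacs, which to my understanding decodes such strings with ast.literal_eval. Under that assumption, the sketch below shows the difference; the decode helper is my own simplification, not the yacs API.

```python
from ast import literal_eval


def decode(value):
    """Decode a config string the way a literal_eval-based loader would."""
    try:
        return literal_eval(value)
    except (ValueError, SyntaxError):
        return value  # plain strings such as R-50-FPN are kept as they are


without_comma = decode('("ship_det_airbus_train")')
with_comma = decode('("ship_det_airbus_train",)')

print(type(without_comma).__name__)
print(type(with_comma).__name__)
```

Parentheses without a comma are only grouping, so the first value is a plain string, while the config expects a tuple of dataset names.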

Launch script run.sh:

#!/usr/bin/env bash


train_py="{path_to}/tools/train_net.py"
test_py="{path_to}/tools/test_net.py"
config_yaml="scratch_e2e_faster_rcnn_R_50_FPN_3x_gn.yaml"

# single GPU
#python ${train_py} --config-file ${config_yaml}
#python ${test_py} --config-file ${config_yaml}

# multi-GPUs
export NGPUS=8
python -m torch.distributed.launch --nproc_per_node=${NGPUS} ${train_py} --config-file ${config_yaml}
#python -m torch.distributed.launch --nproc_per_node=${NGPUS} ${test_py} --config-file ${config_yaml}
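
The SOLVER values in the config assume 8 GPUs with 16 images per batch. If you switch to the single-GPU commands, a common heuristic is the linear scaling rule: scale the learning rate with the batch size and stretch the schedule by the inverse factor. A minimal sketch, where new_batch=2 is a hypothetical single-GPU batch size and rescale_solver is a helper name of my own:

```python
def rescale_solver(base_lr, steps, max_iter, old_batch, new_batch):
    """Linear scaling rule: LR scales with batch size, the schedule inversely."""
    ratio = new_batch / old_batch
    return {
        "BASE_LR": base_lr * ratio,
        "STEPS": tuple(int(round(s / ratio)) for s in steps),
        "MAX_ITER": int(round(max_iter / ratio)),
        "IMS_PER_BATCH": new_batch,
    }


# Values from the config above, rescaled from 16 images per batch to 2.
single_gpu = rescale_solver(
    base_lr=0.02,
    steps=(210000, 250000),
    max_iter=270000,
    old_batch=16,
    new_batch=2,
)
for key, value in single_gpu.items():
    print(key, value)
```

This is only a starting point; the resulting schedule is long, and the values may need tuning for your data.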

Run it:

(maskrcnn-benchmark)$ bash run.sh