mmrotate简单使用
介绍
dota数据
mmrotate是一个旋转框的目标检测工具,常用于遥感图像,适用于4个坐标构成的旋转方形框,数据集以dota为例,格式如下:
- Data
- labels
- 0001.txt
- 0002.txt
- images
- 0001.png
- 0002.png
- labels
0001.txt :
143 430 175 380 230 416 198 466 B 0
154 140 185 139 189 183 157 184 A 0
189 138 221 137 225 182 193 183 A 0
257 137 289 135 293 180 261 181 A 0
分别是x1,y1,x2,y2,x3,y3,x4,y4,classes,其中最后一个数字不可缺少,0表示易识别,1表示难识别
mmrotate调整内容
1.调整数据集和类别
(1)在mmrotate/datasets/dota.py内调整classes
(2)在configs的dota.py中调整data_root,dataloader的ann_files,img_path
(3)调整model的head部分的num_classes
2.预训练模型
model = dict(
backbone=dict(
init_cfg=dict(type='Pretrained', prefix='backbone.', checkpoint=coco_ckpt)),
neck=dict(
init_cfg=dict(type='Pretrained', prefix='neck.',checkpoint=coco_ckpt)),
bbox_head=dict(
init_cfg=dict(type='Pretrained', prefix='bbox_head.', checkpoint=coco_ckpt)))
3.lr与batch_size
optim_wrapper = dict(
optimizer=dict(lr=0.00025, type='AdamW', weight_decay=0.05))
param_scheduler = [
dict(
T_max=16,
begin=0,
by_epoch=True,
convert_to_iter_based=True,
end=16,
eta_min=5e-06,
type='CosineAnnealingLR'),
]
修改dataloader的batch_size即可
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=16, val_interval=2)
开始训练
以rtmdet为例
python ./tools/train.py ./configs/rotated_rtmdet/rtmdet.py
多卡
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 ./tools/train.py ./configs/rotated_rtmdet/rtmdet.py --launcher pytorch
不同版本的test.py有所区别,详见test.py