主要使用swin trasnsformer试了一下sar图像的目标检测,用了舰船ssdd数据集和地面目标MSTAR数据集。
MMDet安装
MMDet地址:https://github.com/open-mmlab/mmdetection
直接pull下来后按照官方文档进行安装环境即可。
记得如果克隆环境或转移到别的环境,需要重新setup一下
python setup.py develop
Swin Transformer代码
1.创建configs下的配置文件
configs/swin下创建一个faster_rcnn_swin_l-p4-w12_coco.py
在这个文件中可以修改学习率、迭代次数等参数。
_base_ = [
'../_base_/models/faster_rcnn_swin_large_fpn.py',
'../_base_/datasets/faster_rcnn_coco_instance.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
optimizer = dict(
_delete_=True,
type='AdamW',
# lr=0.0001,
lr=0.000051,
betas=(0.9, 0.999),
weight_decay=0.05,
paramwise_cfg=dict(
custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}))
lr_config = dict(warmup_iters=1000, step=[27, 33])
runner = dict(type='EpochBasedRunner', max_epochs=36)
2.创建model文件
在/base/models/中新建faster_rcnn_swin_large_fpn.py文件
在文件中可以修改网络backbone、neck等配置,这里使用swin的large模型,PAFPN为neck。
# model settings
pretrained = 'D:/Project/mmdetection-master/checkpoints/swin_large_patch4_window12_384_22k.pth'
# 1. ROI 0.5-0.7
# 2. pafpn
# 3. albu_train_transforms
# 4. 多尺度
model = dict(
type='FasterRCNN',
backbone=dict(
type='SwinTransformer',
embed_dims=192,
depths=[2, 2, 18, 2],
num_heads=[6, 12, 24, 48],
window_size=12,
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.2,
patch_norm=True,
out_indices=(0, 1, 2, 3),
with_cp=False,
convert_weights=True,
init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
neck=dict(
type='PAFPN',
in_channels=[192, 384, 768, 1536],
out_channels=256,
num_outs=5),
rpn_head=dict(
type='RPNHead',
in_channels=256,
feat_channels=256,
anchor_generator=dict(