Swin Transformer做主干的 Faster RCNN 目标检测网络(mmdetection)

B站视频教程合集地址Swin Transformer做主干的 Faster RCNN 目标检测网络
 

一、所需软件(包)介绍

  • 项目工程:mmdetection,直接去github拉取代码即可,拉取位置:mmdetection ,确保当前mmdetection版本支持mmcv 1.3.17,因为后面使用的环境是mmcv 1.3.17的,mmdet与mmcv版本对应关系参考:mmdet与mmcv版本 ,如果未来master支持的mmcv版本要求大于1.3.17的话,请按照要求安装对应的版本。
  • 开发环境:与之前 Swin Transformer Object Detection工程所使用的环境相同,安装过程参考:Swin Transformer Object Detection 目标检测-1——环境搭建详细教程

二、环境搭建

  • 如果之前已经创建了 Swin Transformer Object Detection 项目所需的环境的话,可以直接使用,但是会对后面再训练Swin Transformer Object Detection 造成影响(因为mmdetection工程需要对mmdet的版本进行更改才能使用),所以建议再创建一个新的环境给mmdetection使用,或者直接clone一份之前的环境(推荐)。
  • 克隆环境的方式为:conda create -n conda-env2 --clone conda-env1
    • conda-env2 为新创建的环境(从其他环境clone来的)
    • conda-env1 为之前已经有的环境

注:克隆环境需要一段时间,请耐心等待。这样后面我们mmdetection的工程所使用的环境就是新clone的这个。clone 成功后按照下面步骤操作:

  1. 在IDE中配置项目所使用的虚拟环境为我们新克隆的
  2. 进如到虚拟环境后,在mmdetection的项目目录下执行python setup.py develop ,此时确定 mmdet被换成 2.23.0版本。

三、Swin Transformer Faster RCNN 网络结构图

Swin Transformer Faster RCNN 没看到什么官方的名字,索性就这么叫吧。实际上就是Swin Transformer 作为Faster RCNN网络的Backbone(主干特征提取网络)。
在这里插入图片描述

四、Swin Transformer Faster RCNN 网络代码

1. 在configs/swin 目录下新建文件:faster_rcnn_swin_t-p4-w7_fpn_3x_coco.py
文件内容如下:
**注意:**训练的epoch在这个文件中改,我直接设置成了50,大家根据需要修改。

_base_ = [
    '../_base_/models/faster_rcnn_swin_fpn.py',
    '../_base_/datasets/faster_rcnn_coco_instance.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]

optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.0001,
    betas=(0.9, 0.999),
    weight_decay=0.05,
    paramwise_cfg=dict(
        custom_keys={
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))
lr_config = dict(warmup_iters=1000, step=[27, 33])
runner = dict(type='EpochBasedRunner', max_epochs=36)

2. 在 configs/base/models 下新建文件:faster_rcnn_swin_fpn.py
文件内容如下:
注意: num_classes 需要根据你数据集的类别进行更改

# model settings
pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth'
model = dict(
    type='FasterRCNN',
    backbone=dict(
        type='SwinTransformer',
        embed_dims=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.2,
        patch_norm=True,
        out_indices=(0, 1, 2, 3),
        with_cp=False,
        convert_weights=True,
        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    neck=dict(
        type='FPN',
        in_channels=[96, 192, 384, 768],
        out_channels=256,
        num_outs=5),
    rpn_head=dict(
        type='RPNHead',
        in_channels=256,
        feat_channels=256,
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]),
        loss_cls=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=4,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
    # model training and testing settings
    train_cfg=dict(
        rpn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.7,
                neg_iou_thr=0.3,
                min_pos_iou=0.3,
                match_low_quality=True,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=256,
                pos_fraction=0.5,
                neg_pos_ub=-1,
                add_gt_as_proposals=False),
            allowed_border=-1,
            pos_weight=-1,
            debug=False),
        rpn_proposal=dict(
            nms_pre=2000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssigner',
                pos_iou_thr=0.5,
                neg_iou_thr=0.5,
                min_pos_iou=0.5,
                match_low_quality=False,
                ignore_iof_thr=-1),
            sampler=dict(
                type='RandomSampler',
                num=512,
                pos_fraction=0.25,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_pre=1000,
            max_per_img=1000,
            nms=dict(type='nms', iou_threshold=0.7),
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100)
        # soft-nms is also supported for rcnn testing
        # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
    ))

3. 在/base/datasets 目录下新建文件:faster_rcnn_coco_instance.py
文件内容如下:
注意:

  1. img_scale、samples_per_gpu、 workers_per_gpu可以根据自己的显存大小适当调大、调小
  2. 数据集配置部分参考B站教程:数据集标注
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize', img_scale=(448, 448), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(448, 448),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=8,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_test2017.json',
        img_prefix=data_root + 'test2017/',
        pipeline=test_pipeline))
evaluation = dict(metric=['bbox'])

4. 修改mmdet/datasets/ 下 coco.py
CLASSES中填写自己的分类:例如 CLASSES = ('person', 'bicycle', 'car')
当只有一个类别时,多加一个逗号:CLASSES = ('person',)

五、数据集

数据集依然使用默认的coco格式,数据集制作参考数据集标注(LabelImg、LabelMe使用方法)
注:其实这里是可以使用voc格式的,先挖个坑,后面补上。

六、训练模型

直接执行: python tools/train.py configs/swin/faster_rcnn_swin_t-p4-w7_fpn_3x_coco.py
注意:第一次执行会下载权值文件,需要等待一段时间,或者用特殊办法快点下载,权值文件会自动保存到你的电脑上,下次运行的时候就不再需要重新下载了,当然也可以和之前一样,提前下载好权值文件,然后配置一下。

七、测试训练效果

添加一个自己的图片在demo目录下,

执行:python demo/image_demo.py demo/000071.jpg configs/swin/faster_rcnn_swin_t-p4-w7_fpn_3x_coco.py work_dirs/faster_rcnn_swin_t-p4-w7_fpn_3x_coco/latest.pth

latest.pth 就是自己训练好的最新的权重文件,默认会放在workdir下。

Q & A

Q1. 报错:ImportError: cannot import name ‘init_random_seed’ from ‘mmdet.apis’
A1:进如到虚拟环境后,在mmdetection的项目目录下执行python setup.py develop ,此时 mmdet被换成 2.23.0版本。

 

关于作者:

  • 个人网站:https://beyonderwei.com
  • 邮箱:beyonderwei@gmail.com

微信公众平台

  • 18
    点赞
  • 103
    收藏
    觉得还不错? 一键收藏
  • 14
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值