基于mmdetection训练Swin Transformer Object Detection

mmdetection官方文档
环境搭建
docker
找了一个torch版本为1.5.1+cu101的docker环境,然后安装mmdetection环境

pip install mmcv-full
git clone https://github.com/SwinTransformer/Swin-Transformer-Object-Detection
cd Swin-Transformer-Object-Detection-master
pip install -r requirements/build.txt
pip install -v -e .

安装apex

git clone https://github.com/NVIDIA/apex
cd apex
pip install -r requirements.txt
python setup.py install --cpp_ext

安装成功

Processing dependencies for apex==0.1
Finished processing dependencies for apex==0.1
  • backbone:mmdet/models/backbones
  • neck:mmdet/models/necks
  • head:mmdet/models/roi_heads
  • BBox Assigner:mmdet/core/bbox/assigners
  • BBox Sampler:mmdet/core/bbox/samplers
  • BBox Encoder:mmdet/core/bbox/coder
  • BBox Decoder:mmdet/core/bbox/coder
  • Loss:mmdet/models/losses
  • BBox PostProcess:mmdet/core/post_processing

在"Swin-Transformer-Object-Detection-master/configs/swin/"目录下,可以看到模型文件,选择对应的修改
以"cascade_mask_rcnn_swin_base_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py"为例:

# head为例
roi_head=dict(
        bbox_head=[
            dict(
                type='ConvFCBBoxHead',
                num_shared_convs=4,
                num_shared_fcs=1,
                in_channels=256,
                conv_out_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=15,  # 修改类别数量
# 根据gpu的数量,使用合适的BN
# norm_cfg=dict(type='SyncBN', requires_grad=True),
norm_cfg=dict(type='BN', requires_grad=True),

# 调整学习率等相关参数,lr = 0.00125*batch_size
optimizer = dict(_delete_=True, type='AdamW', lr=0.00125, betas=(0.9, 0.999), weight_decay=0.05,
                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
                                                 'relative_position_bias_table': dict(decay_mult=0.),
                                                 'norm': dict(decay_mult=0.)}))
# 修改epoch
runner = dict(type='EpochBasedRunner', max_epochs=20)                                                  
# 不适用fp16,将use_fp16改为False
optimizer_config = dict(
    type="DistOptimizerHook",
    update_interval=1,
    grad_clip=None,
    coalesce=True,
    bucket_size_mb=-1,
    use_fp16=False,
)

在"configs/base/datasets/coco_instance.py"中根据需要修改

# 修改数据集的类型,路径
dataset_type = 'CocoDataset'
data_root = '/home/coco/'

# 修改img_size等参数,CUDA out of memory时可以修改
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    # 原本为1333*800
    #dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='Resize', img_scale=(416, 416), keep_ratio=True),

# 修改batch_size
data = dict(
    samples_per_gpu=1, # 每块GPU上的sample个数,batch_size = gpu数目*该参数
    workers_per_gpu=1, # 每块GPU上的workers的个数
    # 以train为例
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json', # 标注路径
        img_prefix=data_root + 'train2017/', # 训练图片路径
        pipeline=train_pipeline),

修改类别:mmdet/datasets/coco.py和 mmdet/core/evaluation/class_names.py文件

class CocoDataset(CustomDataset):

    #CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    #           'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
    #           'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
    #           'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
    #           'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    #           'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    #           'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    #           'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    #           'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
    #           'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    #           'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
    #           'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
    #           'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
    #           'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
    CLASSES = ('person', 'tool_vehicle', 'bicycle', 'motorbike', 'pedal_tricycle', 'car', 'passenger_car',
         'truck', 'police_car', 'ambulance', 'bus', 'dump_truck', 'tanker', 'roadblock', 'fire_car')
def coco_classes():
    return ['person', 'tool_vehicle', 'bicycle', 'motorbike', 'pedal_tricycle', 'car', 'passenger_car',
         'truck', 'police_car', 'ambulance', 'bus', 'dump_truck', 'tanker', 'roadblock', 'fire_car']

修改"./tools/train.py"文件

# 选取其中一种版本,单机版本 MMDataParallel、分布式(单机多卡或多机多卡)版本 MMDistributedDataParallel
parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')

模型预训练,权重加载、保存参数,config/base/default_runtime.py文件

checkpoint_config = dict(interval=1) # 每训练一个epoch,保存一次权重
load_from = None # 加载backbone权重
resume_from = None # 继续训练

训练模型
使用编号为3的单个gpu训练

python ./tools/train.py configs/swin/cascade_mask_rcnn_swin_base_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py --gpu-ids 3

使用多gpu训练

tools/dist_train.sh configs/swin/cascade_mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py 4

训练Log及权重
保存在"Swin-Transformer-Object-Detection-master/work_dirs/"中

coco测试

python tools/test.py configs/swin/cascade_mask_rcnn_swin_small_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py cascade_mask_rcnn_swin_small_patch4_window7.pth --eval segm

输出demo,输出为cls,x1,y1,x2,y2的txt格式

from argparse import ArgumentParser
from mmdet.apis import inference_detector, init_detector
import numpy as np
import os
from tqdm import tqdm

def main():
    parser = ArgumentParser()
    parser.add_argument('--img-path', default='/data/wj/test/',help='Image file')
    parser.add_argument('--config', default='../work_dirs/cascade_rcnn_x101_64x4d_fpn_20e_coco/cascade_rcnn_x101_64x4d_fpn_20e_coco.py' ,help='Config file')
    parser.add_argument('--checkpoint', default='../work_dirs/cascade_rcnn_x101_64x4d_fpn_20e_coco/latest.pth', help='Checkpoint file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='bbox score threshold')
    args = parser.parse_args()
    imgs_path = args.img_path
    save_path = '../output/'

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    for img_path in tqdm(os.listdir(imgs_path)):
        img = os.path.join(imgs_path, img_path)
        result = inference_detector(model, img)
        bboxes = np.vstack(result)
        labels = [
            np.full(bbox.shape[0], i, dtype=np.int32)
            for i, bbox in enumerate(result)
        ]
        labels = np.concatenate(labels)
        score_thr = args.score_thr
        if score_thr > 0:
            assert bboxes.shape[1] == 5
            scores = bboxes[:, -1]
            inds = scores > score_thr
            bboxes = bboxes[inds, :]
            labels = labels[inds]
        if len(bboxes) == 0:
            txt_path = os.path.join(save_path, '{}.txt'.format(img_path.split('.')[0]))
            with open(txt_path, 'w') as f:
                f.write("")
        for i, (bbox, label) in enumerate(zip(bboxes, labels)):
            bbox_int = bbox.astype(np.int32)
            x1, y1, x2, y2, conf = bbox_int
            txt_path = os.path.join(save_path, '{}.txt'.format(img_path.split('.')[0]))
            with open(txt_path, 'a') as f:
                f.write("{} {} {} {} {}\n".format(label, x1, y1, x2, y2))

踩过的坑及解决方案:
error with env var RANK
参考:
轻松掌握 MMDetection 整体构建流程(一)

  • 10
    点赞
  • 72
    收藏
    觉得还不错? 一键收藏
  • 32
    评论
### 回答1: Swin Transformer 目标检测是一种基于 Swin Transformer 模型的目标检测算法。它采用了一种新的 Transformer 架构,能够在保持高精度的同时,大幅提高计算效率。该算法在 COCO 数据集上取得了 SOTA 的结果。 ### 回答2: Swin Transformer是一种基于Transformer架构的新型神经网络模型,在目标检测任务中表现出色。它的设计思路主要是通过分解高分辨率特征图的位置编码,将计算复杂度从O(N^2)降低到O(N),极大地提高了模型的计算效率。 Swin Transformer在目标检测任务上的应用主要通过两个关键方面来进行:Swin Transformer Backbone和Swin Transformer FPN。 Swin Transformer Backbone是指将Swin Transformer应用于骨干网络的部分。传统的目标检测模型通常使用ResNet或者EfficientNet等CNN架构作为骨干网络,而Swin Transformer通过将Transformer的自注意力机制应用于骨干网络中,使得模型可以更好地学习到不同尺度和位置的特征信息。 Swin Transformer FPN则是指利用Swin Transformer模型中的特征金字塔网络(Feature Pyramid Network)来进行目标检测。特征金字塔网络通过将不同层次的特征图进行融合,使得模型可以同时获得高级语义信息和低级细节信息,从而提升目标检测的准确性和鲁棒性。 相比于传统的目标检测模型,Swin Transformer在计算效率和准确性上都有显著的提升。它不仅在COCO数据集上取得了当前最好的单模型性能,而且在推理速度上也优于其他同等性能的模型。因此,Swin Transformer在目标检测领域具有广泛的应用前景。 ### 回答3: Swin Transformer是一种基于Transformers的对象检测模型。它是在Transformer架构上进行了改进和优化,以适用于目标检测任务。 与传统的卷积神经网络不同,Swin Transformer使用的是一种局部注意力机制,它能够在图像中进行局部区域的特征提取和交互。这种局部注意力机制能够有效地减少计算复杂度,提升模型的性能。 Swin Transformer利用了一个分层的网络结构,其中每个层级都有多个Swin Transformer块。每个Swin Transformer块由两个子层组成,分别是局部窗口注意力机制子层和跨窗口注意力机制子层。局部窗口注意力机制子层用于提取特定区域的局部特征,而跨窗口注意力机制子层用于不同区域之间的特征交互。 在训练过程中,Swin Transformer还使用了分布式权重梯度传播算法,以加快训练速度。此外,Swin Transformer还采用了数据增强技术,如随机缩放和水平翻转,以提高模型的泛化能力。 实验证明,Swin Transformer在COCO数据集上取得了很好的性能,在目标检测任务上超过了传统的卷积神经网络模型。它在准确性和效率方面表现优异,对于大规模的对象检测任务具有很高的可扩展性。 总之,Swin Transformer是一种基于Transformers的对象检测模型,通过优化的局部注意力机制和分布式训练算法,能够在目标检测任务中取得出色的性能。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 32
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值