4.Swin Transformer目标检测——训练数据集

1.centos7 安装显卡驱动、cuda、cudnn-CSDN博客

2.安装conda python库-CSDN博客

3.Cenots Swin-Transformer-Object-Detection环境配置-CSDN博客

步骤1:准备待训练的coco数据集

下载地址:https://download.csdn.net/download/malingyu/88519420

https://download.csdn.net/download/malingyu/88519411

说明:由于数据集比较大,分开两个资源下载

在项目跟目录,新建目录data/coco,将下载的资源直接放到文件夹中

复制test2017,分布为train2017、val2017。

步骤2:修改tools/tran.py文件

其中config添加上默认的路径

def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('--config',default='../configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco.py', help='train config file path')

步骤3:修改文件configs/_base_/default_runtime.py

添加上下载好的模型路径。

步骤4.修改文件configs/_base_/dataset/coco_instance.py

补充好data_root,和后面的文件夹路径

dataset_type = 'CocoDataset'
data_root = '../data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
evaluation = dict(metric=['bbox', 'segm'])

步骤5:进入tools目录

执行python tran.py文件

运行成功,可以进行数据的训练。

报错问题:TypeError: FormatCode() got an unexpected keyword argument ‘verify’

原因:yapf版本过高

由0.40.2 切换成 0.40.1问题解决

pip install yapf==0.40.1

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值