使用Run Transformer Objection 训练自己的数据集

使用Run Transformer Objection 训练自己的数据集

docs

https://github.com/aniie7/Swin-Transformer-Object-Detection

git clone git@github.com:aniie7/Swin-Transformer-Object-Detection.git
数据集准备

SWIN的数据集为COCO格式,我之前的数据是YOLO,所所以需要进行转换。这里先转为VOC,在转为COCO。

之前的目录结构

├── data
│ ├── JPEGImages:/*.png
│ ├── labels:*.txt

YOLO2VOC[抄的]-github

import os, sys
import glob
from PIL import Image
import argparse


def txtLabel_to_xmlLabel(classes_file, source_txt_path, source_img_path, save_xml_path):
    if not os.path.exists(save_xml_path):
        os.makedirs(save_xml_path)
    classes = open(classes_file).read().splitlines()
    print(classes)
    for file in os.listdir(source_txt_path):
        img_name = file.replace('.txt', '.png')
        img_path = os.path.join(source_img_path, file.replace('.txt', '.png'))  # png to jpg
        img_file = Image.open(img_path)
        txt_file = open(os.path.join(source_txt_path, file)).read().splitlines()
        print(txt_file)
        xml_file = open(os.path.join(save_xml_path, file.replace('.txt', '.xml')), 'w')
        width, height = img_file.size
        xml_file.write('<annotation>\n')
        xml_file.write('\t<folder>simple</folder>\n')
        xml_file.write('\t<filename>' + str(img_name) + '</filename>\n')
        xml_file.write('\t<size>\n')
        xml_file.write('\t\t<width>' + str(width) + ' </width>\n')
        xml_file.write('\t\t<height>' + str(height) + '</height>\n')
        xml_file.write('\t\t<depth>' + str(3) + '</depth>\n')
        xml_file.write('\t</size>\n')

        for line in txt_file:
            print(line)
            line_split = line.split(' ')
            x_center = float(line_split[1])
            y_center = float(line_split[2])
            w = float(line_split[3])
            h = float(line_split[4])
            xmax = int((2 * x_center * width + w * width) / 2)
            xmin = int((2 * x_center * width - w * width) / 2)
            ymax = int((2 * y_center * height + h * height) / 2)
            ymin = int((2 * y_center * height - h * height) / 2)

            xml_file.write('\t<object>\n')
            xml_file.write('\t\t<name>' + str(classes[int(line_split[0])]) + '</name>\n')
            xml_file.write('\t\t<pose>Unspecified</pose>\n')
            xml_file.write('\t\t<truncated>0</truncated>\n')
            xml_file.write('\t\t<difficult>0</difficult>\n')
            xml_file.write('\t\t<bndbox>\n')
            xml_file.write('\t\t\t<xmin>' + str(xmin) + '</xmin>\n')
            xml_file.write('\t\t\t<ymin>' + str(ymin) + '</ymin>\n')
            xml_file.write('\t\t\t<xmax>' + str(xmax) + '</xmax>\n')
            xml_file.write('\t\t\t<ymax>' + str(ymax) + '</ymax>\n')
            xml_file.write('\t\t</bndbox>\n')
            xml_file.write('\t</object>\n')
        xml_file.write('</annotation>')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--classes_file', type=str, default="classes.names")
    parser.add_argument('--source_txt_path', type=str,
                        default="/home/qiucm/lan/data/transformer/Swin-Transformer-Object-Detection/data/labels")
    parser.add_argument('--source_img_path', type=str,
                        default="/home/qiucm/lan/data/transformer/Swin-Transformer-Object-Detection/data/JPEGImages")
    parser.add_argument('--save_xml_path', type=str,
                        default="/home/qiucm/lan/data/transformer/Swin-Transformer-Object-Detection/data/VOCAnnotations")
    opt = parser.parse_args()

    txtLabel_to_xmlLabel(opt.classes_file, opt.source_txt_path, opt.source_img_path, opt.save_xml_path)

之后文件目录为

├── data
│ ├── JPEGImages:/*.png
│ ├── labels:*.txt
│ └── VOCAnnotations*.xml

VOC2COCO[抄的]-github
# -*- coding=utf-8 -*-
# !/usr/bin/python

import sys
import os
import shutil
import numpy as np
import json
import xml.etree.ElementTree as ET
# 检测框的ID起始值
START_BOUNDING_BOX_ID = 1
# 类别列表无必要预先创建,程序中会根据所有图像中包含的ID来创建并更新
PRE_DEFINE_CATEGORIES = {"smoke": 0}


# If necessary, pre-define category and its id
#  PRE_DEFINE_CATEGORIES = {"aeroplane": 1, "bicycle": 2, "bird": 3, "boat": 4,
#  "bottle":5, "bus": 6, "car": 7, "cat": 8, "chair": 9,
#  "cow": 10, "diningtable": 11, "dog": 12, "horse": 13,
#  "motorbike": 14, "person": 15, "pottedplant": 16,
#  "sheep": 17, "sofa": 18, "train": 19, "tvmonitor": 20}


def get(root, name):
    vars = root.findall(name)
    return vars


def get_and_check(root, name, length):
    vars = root.findall(name)
    if len(vars) == 0:
        raise NotImplementedError('Can not find %s in %s.' % (name, root.tag))
    if length > 0 and len(vars) != length:
        raise NotImplementedError('The size of %s is supposed to be %d, but is %d.' % (name, length, len(vars)))
    if length == 1:
        vars = vars[0]
    return vars


def convert(xml_list, xml_dir, json_file):
    '''
    :param xml_list: 需要转换的XML文件列表
    :param xml_dir: XML的存储文件夹
    :param json_file: 导出json文件的路径
    :return: None
    '''
    list_fp = xml_list
    image_id = 1
    # 标注基本结构
    json_dict = {"images": [],
                 "type": "instances",
                 "annotations": [],
                 "categories": []}
    categories = PRE_DEFINE_CATEGORIES
    bnd_id = START_BOUNDING_BOX_ID
    for line in list_fp:
        line = line.strip()
        print(" Processing {}".format(line))
        # 解析XML
        xml_f = os.path.join(xml_dir, line)
        tree = ET.parse(xml_f)
        root = tree.getroot()
        filename = root.find('filename').text
        # 取出图片名字
        image_id += 1
        size = get_and_check(root, 'size', 1)
        # 图片的基本信息
        width = int(get_and_check(size, 'width', 1).text)
        height = int(get_and_check(size, 'height', 1).text)
        image = {'file_name': filename,
                 'height': height,
                 'width': width,
                 'id': image_id}
        json_dict['images'].append(image)
        # 处理每个标注的检测框
        for obj in get(root, 'object'):
            # 取出检测框类别名称
            category = get_and_check(obj, 'name', 1).text
            # 更新类别ID字典
            if category not in categories:
                new_id = len(categories)
                categories[category] = new_id
            category_id = categories[category]
            bndbox = get_and_check(obj, 'bndbox', 1)
            xmin = int(get_and_check(bndbox, 'xmin', 1).text) - 1
            ymin = int(get_and_check(bndbox, 'ymin', 1).text) - 1
            xmax = int(get_and_check(bndbox, 'xmax', 1).text)
            ymax = int(get_and_check(bndbox, 'ymax', 1).text)
            assert (xmax > xmin)
            assert (ymax > ymin)
            o_width = abs(xmax - xmin)
            o_height = abs(ymax - ymin)
            annotation = dict()
            annotation['area'] = o_width * o_height
            annotation['iscrowd'] = 0
            annotation['image_id'] = image_id
            annotation['bbox'] = [xmin, ymin, o_width, o_height]
            annotation['category_id'] = category_id
            annotation['id'] = bnd_id
            annotation['ignore'] = 0
            # 设置分割数据,点的顺序为逆时针方向
            annotation['segmentation'] = [[xmin, ymin, xmin, ymax, xmax, ymax, xmax, ymin]]

            json_dict['annotations'].append(annotation)
            bnd_id = bnd_id + 1

    # 写入类别ID字典
    for cate, cid in categories.items():
        cat = {'supercategory': 'none', 'id': cid, 'name': cate}
        json_dict['categories'].append(cat)
    # 导出到json
    # mmcv.dump(json_dict, json_file)
    print(type(json_dict))
    json_data = json.dumps(json_dict)
    with  open(json_file, 'w') as w:
        w.write(json_data)


if __name__ == '__main__':
    root_path = '/home/qiucm/lan/data/transformer/Swin-Transformer-Object-Detection/data'

    if not os.path.exists(os.path.join(root_path, 'mycocodata/annotations')):
        os.makedirs(os.path.join(root_path, 'mycocodata/annotations'))
    if not os.path.exists(os.path.join(root_path, 'mycocodata/train2017')):
        os.makedirs(os.path.join(root_path, 'mycocodata/train2017'))
    if not os.path.exists(os.path.join(root_path, 'mycocodata/val2017')):
        os.makedirs(os.path.join(root_path, 'mycocodata/val2017'))
    xml_dir = os.path.join(root_path, 'VOCAnnotations')  # 已知的VOC2012的标注

    xml_labels = os.listdir(xml_dir)
    np.random.shuffle(xml_labels)
    split_point = int(len(xml_labels) / 10)

    # validation data
    xml_list = xml_labels[0:split_point]
    json_file = os.path.join(root_path, 'mycocodata/annotations/detections_val2017.json')
    convert(xml_list, xml_dir, json_file)
    for xml_file in xml_list:
        img_name = xml_file[:-4] + '.png'  
        shutil.copy(os.path.join(root_path, 'JPEGImages', img_name),
                    os.path.join(root_path, 'mycocodata/val2017', img_name))
    # train data
    xml_list = xml_labels[split_point:]
    json_file = os.path.join(root_path, 'smoke_coco/annotations/detections_train2017.json')
    convert(xml_list, xml_dir, json_file)
    for xml_file in xml_list:
        img_name = xml_file[:-4] + '.png'
        shutil.copy(os.path.join(root_path, 'JPEGImages', img_name),
                    os.path.join(root_path, 'mycocodata/train2017', img_name))

这时文件目录为

├── data
│ ├── JPEGImages:/*.png
│ ├── labels:*.txt
│ ├── smoke_coco
│ │ ├── annotations:*.json
│ │ ├── train2017:*.png
│ │ └── val2017:*.png
│ └── VOCAnnotations*.xml

数据集的正确是很重要的,在准备好之后可以验证一下,不了解COCO数据集标注格式的可以参考:coco数据集介绍

配置文件准备

  • configs/swin/下的一个配置文件,这个的选择指定了下面几个文件的。
  • configs/_base_/datasets/coco_detection.py
  • configs/_base_/models/cascade_mask_rcnn_swin_fpn.py
  • configs/_base_/default_runtime.py
configs/swin/cascade_mask_rcnn_swin_small_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py

其中首行的_base_列表包含了模型配置文件、数据集配置配置文件,训练参数(lr等)、有关训练配置文件。

_base_ = [
    '../_base_/models/cascade_mask_rcnn_swin_fpn.py',
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_3x.py', '../_base_/default_runtime.py'
]

首先把这个文件里的num_classes=80改为自己数据集所含的类数,这个文件共3处。

文件后部

optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
                                                 'relative_position_bias_table': dict(decay_mult=0.),
                                                 'norm': dict(decay_mult=0.)}))
lr_config = dict(step=[27, 33])
runner = dict(type='EpochBasedRunnerAmp', max_epochs=36)

这里的max_epochs=36为训练次数?

configs/_base_/datasets/coco_detection.py

这个文件配置数据集目录,训练策略,如果数据的目录名称结构和下面文件描述的不同,改其一即可。samples_per_gpu=2,大概是batchsize的意思,如果OOM可以将其改小,workers_per_gpu=2,应该是数据加载时的参数,Ubuntu设2是没问题的。

其中带#Add注释的是为解决某个ERROR而添加。

dataset_type = 'CocoDataset'
data_root = 'data/smoke_coco/'
...
...
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        classes=('smoke',),  # Add
        ann_file=data_root + 'annotations/detections_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        classes=('smoke',),  # Add
        ann_file=data_root + 'annotations/detections_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        classes=('smoke',),  # Add
        ann_file=data_root + 'annotations/detections_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
evaluation = dict(interval=1, metric='bbox')
configs/_base_/models/cascade_mask_rcnn_swin_fpn.py

这里为model settings,需要更改的是把classes_num改为我们需要的,该文件共3处。

configs/_base_/default_runtime.py

checkpoint_config = dict(interval=5):每隔几(5)个Epoch保存一次权重文件。

load_from = "checkpoints/cascade_mask_rcnn_swin_small_patch4_window7.pth":指定预训练文件加载方式。这里也可以使用命令行参数指定,但我一开始使用的时候出错,可能不是这个原因。

mmdet/datasets/coco.py

修改内部的CLASSES = ('Your Classes')

TRAIN

单GPU

python tools/train.py configs/swin/cascade_mask_rcnn_swin_small_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py

多GPU

tools/dist_train.sh configs/swin/cascade_mask_rcnn_swin_small_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py 3

ERRORS

1、KeyError:
KeyError: 'SwinTransformer is not in the backbone registry'

https://github.com/microsoft/Swin-Transformer/issues/95

i uninstall yacs and reinstall yacs==0.8.1 sloved

试了,但我似乎也没有安装这个包

https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/issues/9

尽量避免跑上面那一份代码,否则可能会导致一些奇怪的问题,建议作者在 README 里说明一下。

试了,没解决

2、subprocess.CalledProcessError
subprocess.CalledProcessError: Command '['/home/qiucm/anaconda3/envs/swin/bin/python', '-u', 'tools/train.py', '--local_rank=2', 'configs/swin/aniie_swin.py']' returned non-zero exit status 1

这个报错应该是多卡训练的时候出现的。但错误信息行该不在这一行,


3、AssertionError
'AssertionError: The `num_classes` (1) in Shared2FCBBoxHead of MMDataParallel does not matches the length of `CLASSES` 5) in CocoDataset

https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/issues/44

ffletcherr

you changed all num_classes=80 to num_classes=1 but Likely had a syntax issues in dataset config file. The correct way to modify dataset config file is (comma after class name is important) :

train=dict(
        type=dataset_type,
        # add this line :
        classes = ('yourClass1', 'yourClass2'), # or for One Class :  ('yourClass1',), Notice that comma in necessery 
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        seg_prefix=data_root + 'stuffthingmaps/train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        # and this line :
        classes = ('yourClass1', 'yourClass2'), # or for One Class :  ('yourClass1',), Notice that comma in necessery 
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),

修改文件configs/_base_/datasets/coco_detection.py文件里的

train=dict(
    type=dataset_type,
    classes=('smoke',),  # ADD LINE
    ann_file=data_root + 'annotations/detections_train2017.json',
    img_prefix=data_root + 'train2017/',
    pipeline=train_pipeline),
val=dict(
    type=dataset_type,
    classes=('smoke',),  # ADD LINE
    ann_file=data_root + 'annotations/detections_val2017.json',
    img_prefix=data_root + 'val2017/',
    pipeline=test_pipeline),
test=dict(
    type=dataset_type,
    classes=('smoke',),  # ADD LINE
    ann_file=data_root + 'annotations/detections_val2017.json',
    img_prefix=data_root + 'val2017/',
    pipeline=test_pipeline)

即使只有一类,类名后的,也是必要的

### 回答1: 要使用Swin Transformer训练自己的数据集,需要进行以下步骤: 1. 准备数据集:将自己的数据集准备好,并将其划分为训练集、验证集和测试集。 2. 安装Swin Transformer:在本地或云端安装Swin Transformer,可以使用PyTorch框架进行安装。 3. 配置训练参数:根据自己的数据集和需求,配置训练参数,如学习率、批次大小、训练轮数等。 4. 定义模型:根据自己的数据集和需求,定义Swin Transformer模型,可以使用训练模型进行fine-tuning。 5. 训练模型:使用定义好的模型和训练参数,对数据集进行训练。 6. 评估模型:使用验证集和测试集对训练好的模型进行评估,可以计算准确率、召回率、F1值等指标。 7. 预测新数据:使用训练好的模型对新数据进行预测,可以得到分类结果或回归结果。 以上是使用Swin Transformer训练自己的数据集的基本步骤,具体实现需要根据自己的需求进行调整。 ### 回答2: 总体来说,针对个人数据集进行Swin Transformer模型的训练需要遵循以下步骤: 1. 数据集准备:首先,需要准备好数据集数据集的准备需要注意的是要有标签数据集,保持数据集的质量高,数据集中的类别要明确,数量要充足。同时,对于数据集中的图像,可以进行预处理操作如裁剪、缩放、翻转等,以适应模型的要求。 2. 划分训练集和测试集:在准备数据集的时候,要将数据集按照训练集和测试集进行划分。通常,可将数据集中的70%作为训练集,30%作为测试集。 3. 数据集加载与预处理:在PyTorch中,可以使用DataLoader来将数据集加载到模型中,并进行数据预处理如归一化等操作,同时还可以设置batch_size、shuffle等参数。 4. 定义模型:使用PyTorch中的Swin Transformer模型,并进行自定义修改以适应自己数据集的处理任务。 5. 定义损失函数与优化器:根据任务目标不同,选择不同的损失函数如交叉熵、均方误差等,并结合优化器如Adam、SGD等进行模型训练。 6. 训练模型:使用DataLoader加载数据集,应用损失函数和优化器,训练模型。可以设置迭代次数、学习率等参数,并进行学习率衰减等技巧来提高模型效果。 7. 模型评估:在训练模型过程中,需要了解模型的表现,可以使用测试集数据集进行模型评估。常用指标如准确率、精确率、召回率、F1值等。 8. 模型调参与优化:根据测试集的表现调整模型参数,如学习率、batch_size等,同时还可以进行模型结构的优化等操作以提高模型的性能。 总的来说,这些步骤有助于构建适用于自己数据集的Swin Transformer模型,并可以及时了解模型的表现及进行调优,从而提高模型的性能。当然,虽然操作流程有些繁琐,但是得到高质量的模型肯定值得一试。 ### 回答3: Swin Transformer是近期提出的一种先进的图像分类模型,在多个视觉领域的任务中都取得了最佳效果。在使用Swin Transformer之前,我们需要准备一个自己的数据集,并利用该数据集对Swin Transformer进行训练以实现图像分类任务。 下面是使用Swin Transformer训练自己的数据集的步骤: 1. 数据预处理:收集并准备数据集,对图片进行裁剪、缩放、旋转等增强操作,同时保证标签信息准确无误。 2. 安装并配置Swin Transformer:在训练之前,需要安装Swin Transformer的相关包并配置环境,例如PyTorch,Torchvision,Pillow等。 3. 训练模型:Swin Transformer是深度神经网络,因此需要大量的计算资源和时间才能完成模型训练。我们需要选择合适的GPU,设置合适的参数,使用训练集进行模型训练。 4. 模型评估:训练完成之后,我们需要将测试集输入到模型中,并计算模型的准确率和损失等指标,以评估模型的性能。 5. 调整模型参数:如果模型性能不理想,我们可以尝试调整模型参数,例如修改Swin Transformer的网络层数、神经元数目等,直到获得最佳的结果。 6. 应用模型:最后,我们可以使用Swin Transformer模型对新的图片进行分类预测,对图像进行分类。 总之,Swin Transformer是一种先进的图像分类模型,能够有效地识别和分类图像。对于训练自己的数据集,我们需要按照上述步骤进行操作,以获得最佳的模型性能。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值