使用 Swin Transformer Object Detection 训练自己的数据集
https://github.com/aniie7/Swin-Transformer-Object-Detection
# HTTPS clone works without a GitHub SSH key (use git@github.com:aniie7/... if you have one set up)
git clone https://github.com/aniie7/Swin-Transformer-Object-Detection.git
数据集准备
SWIN 的数据集为 COCO 格式，我之前的数据是 YOLO 格式，所以需要进行转换。这里先转为 VOC，再转为 COCO。
之前的目录结构
├── data
│ ├── JPEGImages:/*.png
│ ├── labels:*.txt
YOLO2VOC（代码摘自 GitHub）
import os, sys
import glob
from PIL import Image
import argparse
def txtLabel_to_xmlLabel(classes_file, source_txt_path, source_img_path, save_xml_path, img_ext='.png'):
    """Convert YOLO ``.txt`` labels into Pascal VOC-style ``.xml`` annotation files.

    Each non-blank label line is ``<class_index> <x_center> <y_center> <w> <h>``,
    where the four box values are normalised by the image width/height.
    One ``.xml`` file with the same base name is written per label file.

    :param classes_file: text file with one class name per line; line ``i``
        is the name of class index ``i``.
    :param source_txt_path: directory containing the YOLO ``.txt`` label files.
    :param source_img_path: directory containing the images; an image must share
        its base name with its label file.
    :param save_xml_path: output directory for the ``.xml`` files (created if missing).
    :param img_ext: extension of the image files, including the dot.
    :return: None
    :raises ValueError: if a label line references a class index that is not
        listed in ``classes_file``.
    """
    # Class names and file names become XML text, so &, < and > must be escaped.
    from xml.sax.saxutils import escape

    os.makedirs(save_xml_path, exist_ok=True)
    with open(classes_file) as f:
        classes = f.read().splitlines()
    print(classes)

    for file in sorted(os.listdir(source_txt_path)):
        stem, ext = os.path.splitext(file)
        if ext != '.txt':
            continue  # not a label file
        img_name = stem + img_ext
        img_path = os.path.join(source_img_path, img_name)
        if not os.path.isfile(img_path):
            # e.g. a classes.txt written next to the labels by a labelling tool
            print('Skipping {}: image {} not found'.format(file, img_path))
            continue
        # The image is opened only to read its size; close it right away.
        with Image.open(img_path) as img_file:
            width, height = img_file.size
        with open(os.path.join(source_txt_path, file)) as f:
            txt_lines = f.read().splitlines()
        print(file)

        out = ['<annotation>\n',
               '\t<folder>simple</folder>\n',
               '\t<filename>' + escape(img_name) + '</filename>\n',
               '\t<size>\n',
               '\t\t<width>' + str(width) + '</width>\n',
               '\t\t<height>' + str(height) + '</height>\n',
               '\t\t<depth>' + str(3) + '</depth>\n',
               '\t</size>\n']
        for line in txt_lines:
            fields = line.split()  # any run of whitespace separates fields
            if not fields:
                continue  # blank line, e.g. a trailing newline
            class_idx = int(fields[0])
            # Reject negative indices too: classes[-1] would silently pick the last class.
            if not 0 <= class_idx < len(classes):
                raise ValueError('{}: class index {} is not in {}'.format(file, class_idx, classes_file))
            x_center, y_center, w, h = (float(v) for v in fields[1:5])
            # Normalised centre/size -> pixel corners, clamped to the image so
            # boxes that slightly overshoot [0, 1] stay valid VOC coordinates.
            xmax = min(width, int((2 * x_center * width + w * width) / 2))
            xmin = max(0, int((2 * x_center * width - w * width) / 2))
            ymax = min(height, int((2 * y_center * height + h * height) / 2))
            ymin = max(0, int((2 * y_center * height - h * height) / 2))
            out += ['\t<object>\n',
                    '\t\t<name>' + escape(str(classes[class_idx])) + '</name>\n',
                    '\t\t<pose>Unspecified</pose>\n',
                    '\t\t<truncated>0</truncated>\n',
                    '\t\t<difficult>0</difficult>\n',
                    '\t\t<bndbox>\n',
                    '\t\t\t<xmin>' + str(xmin) + '</xmin>\n',
                    '\t\t\t<ymin>' + str(ymin) + '</ymin>\n',
                    '\t\t\t<xmax>' + str(xmax) + '</xmax>\n',
                    '\t\t\t<ymax>' + str(ymax) + '</ymax>\n',
                    '\t\t</bndbox>\n',
                    '\t</object>\n']
        out.append('</annotation>')
        # The XML has no encoding declaration, so parsers will assume UTF-8;
        # write UTF-8 explicitly rather than the platform default.
        with open(os.path.join(save_xml_path, stem + '.xml'), 'w', encoding='utf-8') as xml_file:
            xml_file.write(''.join(out))
if __name__ == '__main__':
    # Each flag's dest (classes_file, source_txt_path, ...) is exactly the
    # name of a txtLabel_to_xmlLabel parameter, so the parsed namespace can be
    # forwarded as keyword arguments.
    parser = argparse.ArgumentParser()
    for flag, default in (
            ('--classes_file', "classes.names"),
            ('--source_txt_path',
             "/home/qiucm/lan/data/transformer/Swin-Transformer-Object-Detection/data/labels"),
            ('--source_img_path',
             "/home/qiucm/lan/data/transformer/Swin-Transformer-Object-Detection/data/JPEGImages"),
            ('--save_xml_path',
             "/home/qiucm/lan/data/transformer/Swin-Transformer-Object-Detection/data/VOCAnnotations"),
    ):
        parser.add_argument(flag, type=str, default=default)
    opt = parser.parse_args()
    txtLabel_to_xmlLabel(**vars(opt))
之后文件目录为
├── data
│ ├── JPEGImages:/*.png
│ ├── labels:*.txt
│ └── VOCAnnotations*.xml
VOC2COCO（代码摘自 GitHub）
# -*- coding=utf-8 -*-
# !/usr/bin/python
import sys
import os
import shutil
import numpy as np
import json
import xml.etree.ElementTree as ET
# Id assigned to the first bounding box of every JSON file written by
# convert(); later boxes in the same file are numbered consecutively.
START_BOUNDING_BOX_ID: int = 1
# Category name -> category id. Listing categories here is optional:
# convert() adds any category name it meets that is missing, and it adds it to
# this very dict, so the ids persist across successive convert() calls.
PRE_DEFINE_CATEGORIES: dict = {"smoke": 0}
# If necessary, pre-define the categories and their ids, e.g. for Pascal VOC:
# PRE_DEFINE_CATEGORIES = {"aeroplane": 1, "bicycle": 2, "bird": 3, "boat": 4,
#                          "bottle":5, "bus": 6, "car": 7, "cat": 8, "chair": 9,
#                          "cow": 10, "diningtable": 11, "dog": 12, "horse": 13,
#                          "motorbike": 14, "person": 15, "pottedplant": 16,
#                          "sheep": 17, "sofa": 18, "train": 19, "tvmonitor": 20}
def get(root, name):
    """Return every element matching ``name`` under ``root`` (an empty list if none)."""
    return root.findall(name)
def get_and_check(root, name, length):
    """Find the elements matching ``name`` under ``root`` and validate their count.

    :param root: element to search under.
    :param name: tag / ElementTree path to look up.
    :param length: expected number of matches; ``0`` or less disables the
        count check. With ``length == 1`` the single element itself is returned.
    :return: the matching element (``length == 1``) or the list of matches.
    :raises NotImplementedError: if nothing matches, or if ``length > 0`` and
        the number of matches differs from it.
    """
    found = root.findall(name)
    count = len(found)
    if not count:
        raise NotImplementedError('Can not find %s in %s.' % (name, root.tag))
    if 0 < length != count:
        raise NotImplementedError('The size of %s is supposed to be %d, but is %d.' % (name, length, count))
    return found[0] if length == 1 else found
def _read_int(parent, tag):
    """Return the text of the single ``tag`` child of ``parent`` as an int.

    Parsed through ``float`` so that values such as ``"123.5"`` (which some VOC
    tools write) do not raise; surrounding whitespace is ignored.
    """
    return int(float(get_and_check(parent, tag, 1).text))


def convert(xml_list, xml_dir, json_file):
    '''
    Convert Pascal VOC XML annotation files into a single COCO-style JSON file.

    :param xml_list: names of the XML files to convert (relative to ``xml_dir``)
    :param xml_dir: directory that stores the XML files
    :param json_file: path of the JSON file to write
    :return: None
    :raises ValueError: if a bounding box has a non-positive width or height
    '''
    json_dict = {"images": [],
                 "type": "instances",
                 "annotations": [],
                 "categories": []}
    # Deliberately the module-level dict rather than a copy: a category first
    # seen while converting one split keeps the same id when a later call
    # converts another split.
    categories = PRE_DEFINE_CATEGORIES
    bnd_id = START_BOUNDING_BOX_ID
    for image_id, line in enumerate(xml_list, start=1):
        line = line.strip()
        print(" Processing {}".format(line))
        root = ET.parse(os.path.join(xml_dir, line)).getroot()
        filename = root.find('filename').text
        size = get_and_check(root, 'size', 1)
        width = _read_int(size, 'width')
        height = _read_int(size, 'height')
        json_dict['images'].append({'file_name': filename,
                                    'height': height,
                                    'width': width,
                                    'id': image_id})
        for obj in get(root, 'object'):
            category = get_and_check(obj, 'name', 1).text
            if category not in categories:
                # max + 1 instead of len(): len() collides with an existing id
                # when the predefined ids start at 1 (e.g. the VOC example).
                categories[category] = max(categories.values(), default=-1) + 1
            category_id = categories[category]
            bndbox = get_and_check(obj, 'bndbox', 1)
            # Only xmin/ymin are shifted by one, presumably converting VOC's
            # 1-based pixel indices to COCO's 0-based [x, y, w, h] origin.
            xmin = _read_int(bndbox, 'xmin') - 1
            ymin = _read_int(bndbox, 'ymin') - 1
            xmax = _read_int(bndbox, 'xmax')
            ymax = _read_int(bndbox, 'ymax')
            # Explicit check instead of assert, which is stripped under -O.
            if xmax <= xmin or ymax <= ymin:
                raise ValueError('Degenerate box (xmin={}, ymin={}, xmax={}, ymax={}) in {}'.format(
                    xmin, ymin, xmax, ymax, line))
            o_width = xmax - xmin
            o_height = ymax - ymin
            json_dict['annotations'].append({
                'area': o_width * o_height,
                'iscrowd': 0,
                'image_id': image_id,
                'bbox': [xmin, ymin, o_width, o_height],
                'category_id': category_id,
                'id': bnd_id,
                'ignore': 0,
                # Rectangle used as the segmentation polygon, points in
                # counter-clockwise order.
                'segmentation': [[xmin, ymin, xmin, ymax, xmax, ymax, xmax, ymin]],
            })
            bnd_id += 1
    json_dict['categories'] = [{'supercategory': 'none', 'id': cid, 'name': cate}
                               for cate, cid in categories.items()]
    with open(json_file, 'w') as w:
        json.dump(json_dict, w)
if __name__ == '__main__':
    root_path = '/home/qiucm/lan/data/transformer/Swin-Transformer-Object-Detection/data'
    # Every output (both JSON files and both image folders) goes under one
    # directory; it has to match data_root = 'data/smoke_coco/' in
    # configs/_base_/datasets/coco_detection.py.
    coco_root = os.path.join(root_path, 'smoke_coco')
    for sub_dir in ('annotations', 'train2017', 'val2017'):
        os.makedirs(os.path.join(coco_root, sub_dir), exist_ok=True)

    xml_dir = os.path.join(root_path, 'VOCAnnotations')  # VOC annotations produced by YOLO2VOC
    # Only .xml files; sort first and shuffle with a fixed seed so the
    # train/val split is reproducible between runs.
    xml_labels = sorted(f for f in os.listdir(xml_dir) if f.endswith('.xml'))
    np.random.RandomState(0).shuffle(xml_labels)
    # Hold out ~10% for validation, but at least one file so the val JSON is not empty.
    split_point = len(xml_labels) // 10
    if split_point == 0 and len(xml_labels) > 1:
        split_point = 1

    # val first, then train: convert() shares the category dict across calls.
    for split, xml_list in (('val2017', xml_labels[:split_point]),
                            ('train2017', xml_labels[split_point:])):
        json_file = os.path.join(coco_root, 'annotations', 'detections_{}.json'.format(split))
        convert(xml_list, xml_dir, json_file)
        for xml_file in xml_list:
            img_name = os.path.splitext(xml_file)[0] + '.png'
            shutil.copy(os.path.join(root_path, 'JPEGImages', img_name),
                        os.path.join(coco_root, split, img_name))
这时文件目录为
├── data
│ ├── JPEGImages:/*.png
│ ├── labels:*.txt
│ ├── smoke_coco
│ │ ├── annotations:*.json
│ │ ├── train2017:*.png
│ │ └── val2017:*.png
│ └── VOCAnnotations*.xml
数据集的正确性很重要，准备好之后最好验证一下。不了解 COCO 数据集标注格式的可以参考：coco数据集介绍
配置文件准备
configs/swin/
下选择一个配置文件（本文用的是下面列出的最后一个），它通过 _base_ 引用了下面几个基础配置文件：
configs/_base_/datasets/coco_detection.py
configs/_base_/models/cascade_mask_rcnn_swin_fpn.py
configs/_base_/default_runtime.py
configs/swin/cascade_mask_rcnn_swin_small_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py
其中首行的 _base_
列表依次包含了模型配置文件、数据集配置文件、训练策略配置文件（schedule，含 lr 等参数）以及运行时配置文件。
# Base configs this file inherits from: model, dataset, schedule, runtime.
_base_ = [
    '../_base_/models/cascade_mask_rcnn_swin_fpn.py',
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_3x.py',
    '../_base_/default_runtime.py',
]
首先把这个文件里的num_classes=80
改为自己数据集所含的类数,这个文件共3处。
文件后部
# AdamW with weight decay disabled (decay_mult 0) for the absolute position
# embedding, the relative position bias tables and normalisation layers.
# '_delete_': True follows the mmcv config convention of replacing, rather
# than merging into, the optimizer dict inherited from _base_.
optimizer = {
    '_delete_': True,
    'type': 'AdamW',
    'lr': 0.0001,
    'betas': (0.9, 0.999),
    'weight_decay': 0.05,
    'paramwise_cfg': {
        'custom_keys': {
            'absolute_pos_embed': {'decay_mult': 0.},
            'relative_position_bias_table': {'decay_mult': 0.},
            'norm': {'decay_mult': 0.},
        },
    },
}
# Epochs at which the learning-rate schedule steps down.
lr_config = {'step': [27, 33]}
# Total number of training epochs.
runner = {'type': 'EpochBasedRunnerAmp', 'max_epochs': 36}
这里的 max_epochs=36 即训练的总 epoch 数（完整遍历训练集 36 轮），lr_config 中的 step=[27, 33] 则是学习率衰减所在的 epoch。
configs/_base_/datasets/coco_detection.py
这个文件配置数据集目录和数据处理流程（pipeline）。如果自己数据的目录名称或结构与下面文件中描述的不同，修改其中一方使两者一致即可。samples_per_gpu=2 是每张 GPU 上的 batch size（总 batch size = samples_per_gpu × GPU 数），如果出现 OOM（显存不足）可以将其改小；workers_per_gpu=2 是每张 GPU 对应的数据加载进程数，在 Ubuntu 上设为 2 没有问题。
其中带 # Add
注释的行是为解决下文 ERRORS 第 3 条（AssertionError）而添加的。
# Dataset section of configs/_base_/datasets/coco_detection.py (excerpt).
dataset_type = 'CocoDataset'
# Directory holding annotations/, train2017/ and val2017/ (the smoke_coco tree shown above).
data_root = 'data/smoke_coco/'
...  # elided in this excerpt; presumably where train_pipeline / test_pipeline are defined
...
data = dict(
    samples_per_gpu=2,  # images per GPU (per-GPU batch size); lower it on OOM
    workers_per_gpu=2,  # data-loading worker processes per GPU
    train=dict(
        type=dataset_type,
        classes=('smoke',),  # added; the trailing comma makes this a 1-tuple, ('smoke') would be the string 'smoke'
        ann_file=data_root + 'annotations/detections_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        classes=('smoke',),  # added; same class tuple as train
        ann_file=data_root + 'annotations/detections_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        classes=('smoke',),  # added; test reuses the val split
        ann_file=data_root + 'annotations/detections_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
# Evaluate bbox metrics after every epoch.
evaluation = dict(interval=1, metric='bbox')
configs/_base_/models/cascade_mask_rcnn_swin_fpn.py
这里为 model settings，需要把其中的 num_classes 改为我们数据集的类别数，该文件共 3 处（修改时可以搜索 num_classes，确认没有遗漏）。
configs/_base_/default_runtime.py
checkpoint_config = {'interval': 5}  # save a checkpoint every 5 epochs
：每隔 5 个 epoch 保存一次权重文件。
load_from = 'checkpoints/cascade_mask_rcnn_swin_small_patch4_window7.pth'  # pre-trained weights; path relative to where training is launched
：指定要加载的预训练权重文件。这里也可以通过命令行参数指定，但我一开始这样用时出错了（不确定是不是这个原因）。
mmdet/datasets/coco.py
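修改其中的 CLASSES = ('Your Class 1', 'Your Class 2')，改为自己数据集的类别名。只有一类时必须写成 CLASSES = ('Your Class',)：末尾的逗号不能省略，否则 ('Your Class') 只是一个字符串，而不是元组。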
TRAIN
单GPU
# Single GPU: python tools/train.py <config file>
python tools/train.py \
    configs/swin/cascade_mask_rcnn_swin_small_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py
多GPU（最后一个参数 3 为使用的 GPU 数量）
# Multi-GPU: dist_train.sh <config file> <number of GPUs>.
# Invoking it through bash also works when the script has lost its executable bit.
bash tools/dist_train.sh \
    configs/swin/cascade_mask_rcnn_swin_small_patch4_window7_mstrain_480-800_giou_4conv1f_adamw_3x_coco.py \
    3
ERRORS
1、KeyError:
KeyError: 'SwinTransformer is not in the backbone registry'
https://github.com/microsoft/Swin-Transformer/issues/95
I uninstalled yacs and reinstalled yacs==0.8.1, which solved it.
试了,但我似乎也没有安装这个包
https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/issues/9
尽量避免跑上面那一份代码,否则可能会导致一些奇怪的问题,建议作者在 README 里说明一下。
试了,没解决
2、subprocess.CalledProcessError
subprocess.CalledProcessError: Command '['/home/qiucm/anaconda3/envs/swin/bin/python', '-u', 'tools/train.py', '--local_rank=2', 'configs/swin/aniie_swin.py']' returned non-zero exit status 1
这个报错应该是多卡训练时出现的，但真正的错误信息不在这一行：它只说明某个训练子进程以非 0 状态退出了，需要向上翻看日志中该子进程打印的 traceback。
3、AssertionError
AssertionError: The `num_classes` (1) in Shared2FCBBoxHead of MMDataParallel does not matches the length of `CLASSES` 5) in CocoDataset
https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/issues/44
you changed all
num_classes=80
to num_classes=1
but you likely had a syntax issue in the dataset config file. The correct way to modify the dataset config file is as follows (the comma after the class name is important):
train=dict(
    type=dataset_type,
    # add this line:
    classes=('yourClass1', 'yourClass2'),  # or for one class: ('yourClass1',) -- notice that the comma is necessary
    ann_file=data_root + 'annotations/instances_train2017.json',
    img_prefix=data_root + 'train2017/',
    seg_prefix=data_root + 'stuffthingmaps/train2017/',
    pipeline=train_pipeline),
val=dict(
    type=dataset_type,
    # and this line:
    classes=('yourClass1', 'yourClass2'),  # or for one class: ('yourClass1',) -- notice that the comma is necessary
    ann_file=data_root + 'annotations/instances_val2017.json',
    img_prefix=data_root + 'val2017/',
    pipeline=test_pipeline),
修改文件configs/_base_/datasets/coco_detection.py
文件里的
# Excerpt of the data=dict(...) entries; the classes lines are the additions.
train=dict(
    type=dataset_type,
    classes=('smoke',),  # added line; keep the trailing comma so this is a 1-tuple
    ann_file=data_root + 'annotations/detections_train2017.json',
    img_prefix=data_root + 'train2017/',
    pipeline=train_pipeline),
val=dict(
    type=dataset_type,
    classes=('smoke',),  # added line; same class tuple as train
    ann_file=data_root + 'annotations/detections_val2017.json',
    img_prefix=data_root + 'val2017/',
    pipeline=test_pipeline),
test=dict(
    type=dataset_type,
    classes=('smoke',),  # added line; same class tuple as train
    ann_file=data_root + 'annotations/detections_val2017.json',
    img_prefix=data_root + 'val2017/',
    pipeline=test_pipeline)
即使只有一类，类名后的逗号（,）
也是必要的：('smoke') 只是字符串 'smoke'（长度为 5，这正是上面报错中 `CLASSES` 5) 的来源），('smoke',) 才是只含一个元素的元组。