1.代码地址:
https://github.com/xingyizhou/CenterNet
2.编译
按照 https://github.com/xingyizhou/CenterNet/blob/master/readme/INSTALL.md 顺序安装
(1)修改 /anaconda3/envs/CenterNet/lib/python3.6/site-packages/torch/nn/functional.py 文件:找到 torch.batch_norm 调用,将其中的 torch.backends.cudnn.enabled 替换为 False
(2)torch要装1.0版本,
conda install pytorch==1.0.0 torchvision==0.2.1 cuda100 -c pytorch
(3)git clone
https://github.com/CharlesShang/DCNv2到$CenterNet_ROOT/src/lib/models/networks/DCNv2,不需要checkout到0.4版本
./make.sh
3.自己数据集格式
训练使用的是json文件,json文件制作脚本,其中,xml2xml.py是为了将不同的类转为同一类,或者在标注了很多类的xml文件中提取出自己要训练的类别,生成新的xml文件。
(1)xml2json.py
import xml.etree.ElementTree as ET
import os
import json
# Skeleton of the COCO-style dataset; the add*Item helpers append to its lists.
coco = {
    'images': [],
    'type': 'instances',
    'annotations': [],
    'categories': [],
}

# category name -> category id, for every category registered so far
category_set = {}
# file names of every image registered so far
image_set = set()

# Running id counters. Each helper increments its counter before using it,
# so the first category/annotation id is 1 and the first image id is
# 20200000001.
category_item_id = 0
annotation_id = 0
image_id = 20200000000
def addCatItem(name):
    """Register a new category called *name* and return its id.

    The id is the next value of the module-level ``category_item_id``
    counter. The category is appended to ``coco['categories']`` and recorded
    in ``category_set`` so later lookups by name can reuse the id.
    """
    global category_item_id
    category_item_id += 1
    coco['categories'].append({
        'supercategory': 'none',
        'id': category_item_id,
        'name': name,
    })
    category_set[name] = category_item_id
    return category_item_id
def addImgItem(file_name, size):
    """Register an image in the COCO dict and return its id.

    Args:
        file_name: value of the xml ``<filename>`` tag.
        size: dict holding at least the ``'width'`` and ``'height'`` keys.

    Returns:
        The id assigned to the image (next value of the module-level
        ``image_id`` counter).

    Raises:
        Exception: if *file_name*, the width or the height is ``None``.
    """
    global image_id
    if file_name is None:
        raise Exception('Could not find filename tag in xml file.')
    for key in ('width', 'height'):
        if size[key] is None:
            raise Exception('Could not find {} tag in xml file.'.format(key))

    image_id += 1
    coco['images'].append({
        'id': image_id,
        'file_name': file_name,
        'width': size['width'],
        'height': size['height'],
    })
    # Remember the name so the same image is not registered twice.
    image_set.add(file_name)
    return image_id
def addAnnoItem(object_name, image_id, category_id, bbox):
    """Append one bounding-box annotation to ``coco['annotations']``.

    Args:
        object_name: class name of the object (not stored; the category is
            referenced through *category_id*).
        image_id: id of the image the box belongs to.
        category_id: id of the object's category.
        bbox: ``[x, y, w, h]`` with ``(x, y)`` the top-left corner.

    The annotation id is the next value of the module-level
    ``annotation_id`` counter.
    """
    global annotation_id
    x, y, w, h = bbox[:4]
    # The segmentation is the box outline as a flat polygon:
    # left-top, left-bottom, right-bottom, right-top.
    seg = [x, y, x, y + h, x + w, y + h, x + w, y]

    annotation_id += 1
    coco['annotations'].append({
        'segmentation': [seg],
        'area': w * h,
        'iscrowd': 0,
        'ignore': 0,
        'image_id': image_id,
        'bbox': bbox,
        'category_id': category_id,
        'id': annotation_id,
    })
def parseXmlFiles(xml_path):
    """Convert every Pascal VOC ``.xml`` file in *xml_path* into COCO entries.

    For each annotation file one image entry is added through ``addImgItem``
    and one annotation entry per ``<object>`` that has a ``<bndbox>`` through
    ``addAnnoItem``. A category is registered through ``addCatItem`` the first
    time its name is seen. Everything accumulates in the module-level
    ``coco`` dict.

    Args:
        xml_path: directory containing the xml annotation files.

    Raises:
        Exception: if the root tag is not ``annotation``, an image file name
            occurs twice, a ``<size>`` child is repeated, a ``<bndbox>`` has
            no sibling ``<name>``, or a box corner is missing. ``addImgItem``
            raises if the file name, width or height is absent.
    """
    # os.listdir order is arbitrary and ids are assigned in visiting order,
    # so sort to make the generated json reproducible.
    for f in sorted(os.listdir(xml_path)):
        if not f.endswith('.xml'):
            continue

        xml_file = os.path.join(xml_path, f)
        print(xml_file)

        root = ET.parse(xml_file).getroot()
        if root.tag != 'annotation':
            raise Exception('pascal voc xml root element should be annotation, rather than {}'.format(root.tag))

        file_name = root.findtext('filename')
        # Duplicates have to be looked up among the registered file names
        # (image_set), not among the category names.
        if file_name in image_set:
            raise Exception('duplicated image: {}'.format(file_name))

        size = {'width': None, 'height': None, 'depth': None}
        size_elem = root.find('size')
        if size_elem is not None:
            for dim in size_elem:
                if size.get(dim.tag) is not None:
                    raise Exception('xml structure broken at size tag.')
                # int(float(...)) also accepts values written like "500.0".
                size[dim.tag] = int(float(dim.text))

        # Looking tags up by name (instead of walking the children in
        # document order) makes the result independent of the tag order.
        current_image_id = addImgItem(file_name, size)
        print('add image with {} and {}'.format(file_name, size))

        for obj in root.findall('object'):
            object_name = obj.findtext('name')
            current_category_id = None
            if object_name is not None:
                if object_name not in category_set:
                    current_category_id = addCatItem(object_name)
                else:
                    current_category_id = category_set[object_name]

            box = obj.find('bndbox')
            if box is None:
                # An object without a box yields no annotation.
                continue
            if object_name is None:
                raise Exception('xml structure broken at bndbox tag')

            corners = {}
            for tag in ('xmin', 'ymin', 'xmax', 'ymax'):
                text = box.findtext(tag)
                if text is None:
                    raise Exception('xml structure corrupted at bndbox tag.')
                corners[tag] = int(float(text))

            # COCO boxes are [x, y, w, h]; VOC stores the two corners.
            bbox = [
                corners['xmin'],
                corners['ymin'],
                corners['xmax'] - corners['xmin'],
                corners['ymax'] - corners['ymin'],
            ]
            print('add annotation with {},{},{},{}'.format(object_name, current_image_id, current_category_id,
                                                           bbox))
            addAnnoItem(object_name, current_image_id, current_category_id, bbox)
if __name__ == '__main__':
    # Fill these in before running: the directory holding the VOC xml files
    # and the path of the json file to write.
    xml_path = ' '
    json_file = ' '
    parseXmlFiles(xml_path)
    # Context manager so the output is flushed and the handle closed even if
    # serialisation fails.
    with open(json_file, 'w') as out:
        json.dump(coco, out)
(2)xml2xml.py
#!/usr/bin/env python
# coding:utf-8
import os
import re
import sys
import json
import cv2
import shutil
from PIL import Image
from lxml import etree
import xml.etree.ElementTree as et
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
import numpy as np
from lxml.etree import Element, SubElement, tostring
# Classes to keep from the source annotations; objects of any other class
# are dropped.
class_name = ["a", "b"]
# Source class -> class name written to the new xml. A kept class that is
# not listed here keeps its own name.
merged_name = {"a": "c", "b": "c"}
ann_folder1 = ' '  # directory with the source xml files
ann_folder2 = ' '  # directory the filtered xml files are written to

for root, dirs, files in os.walk(ann_folder1, True):
    for file in files:
        if not file.endswith('.xml'):
            continue
        print ('file = ', file)

        # open xml -- join with the directory currently being walked so
        # files inside sub-folders are found as well
        ann_path1 = os.path.join(root, file)
        root1 = et.parse(ann_path1).getroot()

        # write xml
        root2 = Element('annotation')
        SubElement(root2, 'folder').text = 'voc'
        SubElement(root2, 'filename').text = file.replace('.xml', '.jpg')

        src_size = root1.find('size')
        size = SubElement(root2, 'size')
        for tag in ('width', 'height', 'depth'):
            value = src_size.findtext(tag)
            # Skip tags the source does not have rather than writing an
            # empty element.
            if value is not None:
                SubElement(size, tag).text = value

        obj_num = 0
        for obj in root1.findall('object'):
            src_name = obj.findtext('name')
            if src_name not in class_name:
                continue
            obj_num += 1

            new_obj = SubElement(root2, 'object')
            SubElement(new_obj, 'name').text = merged_name.get(src_name, src_name)
            # Fall back to '0' when the source has no <difficult> tag.
            SubElement(new_obj, 'difficult').text = obj.findtext('difficult', default='0')

            src_box = obj.find('bndbox')
            bndbox = SubElement(new_obj, 'bndbox')
            for tag in ('xmin', 'ymin', 'xmax', 'ymax'):
                # Coordinates may be stored as floats; write them as ints.
                SubElement(bndbox, tag).text = str(int(float(src_box.find(tag).text)))

        # Only files that still contain at least one kept object are written.
        if obj_num > 0:
            ann_path2 = os.path.join(ann_folder2, file)
            doc = etree.ElementTree(root2)
            doc.write(ann_path2, pretty_print=True)
在data下建立文件夹,并按如下路径存放文件:c/annotations/train.json、c/annotations/test.json、c/images(该路径存放所有train、test图片)
4.开始训练
(1)在src/lib/datasets/dataset/下copy coco.py 新建一个c.py 文件,其具体修改如下图10个地方:
13行改为class c(data.Dataset): c为自己的类名
14行改为自己的类别数(不包括背景类)
15行改为自己的输入图像大小
16-19行可改为自己的均值方差,也可默认,求均值方差的脚本
import cv2, os, argparse
import numpy as np
from tqdm import tqdm
def main(dirs=r'图片路径'):
    """Print the per-channel mean and std of all images in *dirs*.

    Every image is scaled to [0, 1] and measured with ``cv2.meanStdDev``;
    the per-image values are then averaged over all images (so the reported
    std is the mean of the per-image stds, not the std of the pooled
    pixels). Files that OpenCV cannot decode are skipped.

    Args:
        dirs: directory containing the images; change the default to your
            own image folder or pass the path explicitly.
    """
    m_list, s_list = [], []
    for img_filename in tqdm(os.listdir(dirs)):
        img = cv2.imread(os.path.join(dirs, img_filename))
        if img is None:
            # cv2.imread returns None instead of raising when a file is not
            # a readable image (e.g. a stray text file or a sub-folder).
            continue
        img = img / 255.0
        m, s = cv2.meanStdDev(img)
        m_list.append(m.reshape((3,)))
        s_list.append(s.reshape((3,)))

    if not m_list:
        print('no readable images found in', dirs)
        return

    m = np.array(m_list).mean(axis=0, keepdims=True)
    s = np.array(s_list).mean(axis=0, keepdims=True)
    # The channel order is reversed before printing (cv2 loads BGR, so this
    # prints RGB). NOTE(review): confirm which order the dataset class
    # expects before pasting these values.
    print("mean = ", m[0][::-1])
    print("std = ", s[0][::-1])


if __name__ == '__main__':
    main()
22行修改为super(c,self)
23行self.data_dir = os.path.join(opt.data_dir, 'c')
24行修改'{}2017'为'images'
25行修改'test'为'val',因为只转了train.json和test.json,没有val.json
28行'instances_extreme_{}2017.json'改为'test.json'
37行'instances_{}2017.json'改为'train.json'
39行改为自己的类别[ '__background__', 'c']
41行self._valid_ids = [1]
(2)修改src/lib/datasets/dataset_factory.py
22行添加自己的类别'c': c
(3)修改 src/lib/opts.py
这两处必须改
15 self.parser.add_argument('--dataset', default='c',
16 help='coco | kitti | coco_hp | pascal | c')
338 'ctdet': {'default_resolution': [320, 320], 'num_classes': 1,
339 'mean': [0.408, 0.447, 0.470], 'std': [0.289, 0.274, 0.278],
340 'dataset': 'c'}
下边修改学习率,修改backbone可改可不改
# model
- self.parser.add_argument('--arch', default='dla_34',
+ self.parser.add_argument('--arch', default='dlav0_34',
# train
- self.parser.add_argument('--lr', type=float, default=1.25e-4,
+ self.parser.add_argument('--lr', type=float, default=1e-3,
- self.parser.add_argument('--lr_step', type=str, default='90,120',
+ self.parser.add_argument('--lr_step', type=str, default='90,150',
- self.parser.add_argument('--num_epochs', type=int, default=140,
+ self.parser.add_argument('--num_epochs', type=int, default=200,
(4)修改src/lib/utils/debugger.py
48 elif num_classes == 1 or dataset == 'c':
49 self.names = c_class_name
444 c_class_name = ["c"]
5.训练指令
1.train
在src路径下执行:
(1)不加载权重:python main.py ctdet --exp_id c --batch_size 64 --lr 2e-3
(2)加载权重:python main.py ctdet --exp_id c --batch_size 64 --lr 2e-3 --load_model ../models/ctdet_dla_2x.pth
(3)多卡训练+断点恢复:python main.py ctdet --exp_id c --batch_size 64 --lr 2e-3 --gpus 0,1,2 --load_model ../models/ctdet_dla_2x.pth --resume
2.test
在src路径下执行:
(1)测试mAP:python test.py ctdet --exp_id c --not_prefetch_test --load_model ../exp/ctdet/c/model_best.pth
(2)画图查看效果:python demo.py ctdet --demo ../data/c/images --load_model ../exp/ctdet/c/model_best.pth
(3)带数据增强的预测:python demo.py ctdet --demo ../data/c/images --load_model ../exp/ctdet/c/model_best.pth --flip_test
(4)多尺度预测:python demo.py ctdet --demo ../data/c/images --load_model ../exp/ctdet/c/model_best.pth --test_scales 0.5,0.75,1.0,1.25,1.5
注意,如果多尺度预测报错,原因可能是没有编译nms,编译方法:到 path/to/CenterNet/src/lib/external 目录下,运行:python setup.py build_ext --inplace