Training CenterNet on Your Own Dataset

1. Code repository:

https://github.com/xingyizhou/CenterNet

2. Build

Install in the order given in https://github.com/xingyizhou/CenterNet/blob/master/readme/INSTALL.md

(1) In /anaconda3/envs/CenterNet/lib/python3.6/site-packages/torch/nn/functional.py, find the torch.batch_norm call and replace torch.backends.cudnn.enabled with False.

(2) PyTorch must be version 1.0:

conda install pytorch==1.0.0 torchvision==0.2.1 cuda100 -c pytorch

(3) git clone https://github.com/CharlesShang/DCNv2 into $CenterNet_ROOT/src/lib/models/networks/DCNv2 (there is no need to check out the 0.4 version), then run the following in that directory:

./make.sh

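3. Format of your own dataset

Training uses a JSON annotation file, which is built with the scripts below. xml2json.py converts the Pascal VOC XML annotations to JSON. xml2xml.py merges different classes into a single class, or extracts only the classes you want to train from XML files annotated with many classes, and writes new XML files.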

(1)xml2json.py

import xml.etree.ElementTree as ET
import os
import json
 
coco = dict()
coco['images'] = []
coco['type'] = 'instances'
coco['annotations'] = []
coco['categories'] = []
 
category_set = dict()
image_set = set()
 
category_item_id = 0
image_id = 20200000000
annotation_id = 0
 
def addCatItem(name):
    global category_item_id
    category_item = dict()
    category_item['supercategory'] = 'none'
    category_item_id += 1
    category_item['id'] = category_item_id
    category_item['name'] = name
    coco['categories'].append(category_item)
    category_set[name] = category_item_id
    return category_item_id
 
def addImgItem(file_name, size):
    global image_id
    if file_name is None:
        raise Exception('Could not find filename tag in xml file.')
    if size['width'] is None:
        raise Exception('Could not find width tag in xml file.')
    if size['height'] is None:
        raise Exception('Could not find height tag in xml file.')
    image_id += 1
    image_item = dict()
    image_item['id'] = image_id
    image_item['file_name'] = file_name
    image_item['width'] = size['width']
    image_item['height'] = size['height']
    coco['images'].append(image_item)
    image_set.add(file_name)
    return image_id
 
def addAnnoItem(object_name, image_id, category_id, bbox):
    global annotation_id
    annotation_item = dict()
    annotation_item['segmentation'] = []
    seg = []
    #bbox[] is x,y,w,h
    #left_top
    seg.append(bbox[0])
    seg.append(bbox[1])
    #left_bottom
    seg.append(bbox[0])
    seg.append(bbox[1] + bbox[3])
    #right_bottom
    seg.append(bbox[0] + bbox[2])
    seg.append(bbox[1] + bbox[3])
    #right_top
    seg.append(bbox[0] + bbox[2])
    seg.append(bbox[1])
 
    annotation_item['segmentation'].append(seg)
 
    annotation_item['area'] = bbox[2] * bbox[3]
    annotation_item['iscrowd'] = 0
    annotation_item['ignore'] = 0
    annotation_item['image_id'] = image_id
    annotation_item['bbox'] = bbox
    annotation_item['category_id'] = category_id
    annotation_id += 1
    annotation_item['id'] = annotation_id
    coco['annotations'].append(annotation_item)
 
def parseXmlFiles(xml_path):
    for f in os.listdir(xml_path):
        if not f.endswith('.xml'):
            continue
        
        bndbox = dict()
        size = dict()
        current_image_id = None
        current_category_id = None
        file_name = None
        size['width'] = None
        size['height'] = None
        size['depth'] = None
 
        xml_file = os.path.join(xml_path, f)
        print(xml_file)
 
        tree = ET.parse(xml_file)
        root = tree.getroot()
        if root.tag != 'annotation':
            raise Exception('pascal voc xml root element should be annotation, rather than {}'.format(root.tag))
 
        #elem is <folder>, <filename>, <size>, <object>
        for elem in root:
            current_parent = elem.tag
            current_sub = None
            object_name = None
            
            if elem.tag == 'folder':
                continue
            
            if elem.tag == 'filename':

                file_name = elem.text
                if file_name in image_set:
                    raise Exception('file_name duplicated')
                
            #add img item only after parse <size> tag
            elif current_image_id is None and file_name is not None and size['width'] is not None:
                if file_name not in image_set:
                    current_image_id = addImgItem(file_name, size)
                    print('add image with {} and {}'.format(file_name, size))
                else:
                    raise Exception('duplicated image: {}'.format(file_name))
            #subelem is <width>, <height>, <depth>, <name>, <bndbox>
            for subelem in elem:
                bndbox ['xmin'] = None
                bndbox ['xmax'] = None
                bndbox ['ymin'] = None
                bndbox ['ymax'] = None

                current_sub = subelem.tag
                if current_parent == 'object' and subelem.tag == 'name':
                    object_name = subelem.text
                    if object_name not in category_set:
                        current_category_id = addCatItem(object_name)
                    else:
                        current_category_id = category_set[object_name]

                elif current_parent == 'size':
                    if size[subelem.tag] is not None:
                        raise Exception('xml structure broken at size tag.')
                    size[subelem.tag] = int(float(subelem.text))

                # option is <xmin>, <ymin>, <xmax>, <ymax>, when subelem is <bndbox>
                for option in subelem:
                    if current_sub == 'bndbox':
                        if bndbox[option.tag] is not None:
                            raise Exception('xml structure corrupted at bndbox tag.')
                        bndbox[option.tag] = int(float(option.text))

                # only after parse the <object> tag
                if bndbox['xmin'] is not None:
                    if object_name is None:
                        raise Exception('xml structure broken at bndbox tag')
                    if current_image_id is None:
                        raise Exception('xml structure broken at bndbox tag')
                    if current_category_id is None:
                        raise Exception('xml structure broken at bndbox tag')
                    bbox = []
                    # x
                    bbox.append(bndbox['xmin'])
                    # y
                    bbox.append(bndbox['ymin'])
                    # w
                    bbox.append(bndbox['xmax'] - bndbox['xmin'])
                    # h
                    bbox.append(bndbox['ymax'] - bndbox['ymin'])
                    print('add annotation with {},{},{},{}'.format(object_name, current_image_id, current_category_id,
                                                                   bbox))
                    addAnnoItem(object_name, current_image_id, current_category_id, bbox)
 
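if __name__ == '__main__':
    xml_path = ' '   # directory containing the Pascal VOC xml files
    json_file = ' '  # path of the json file to write, e.g. train.json
    parseXmlFiles(xml_path)
    with open(json_file, 'w') as f:
        json.dump(coco, f)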
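
Before training, it can help to sanity-check the generated JSON. The following is a minimal sketch and not part of the original scripts: check_coco and load_and_check are hypothetical helpers. They assume the file has the images / annotations / categories keys written by xml2json.py, and they verify that every annotation points at a known image and category and that each box is [x, y, w, h] with positive width and height.

```python
import json


def check_coco(coco):
    """Return a list of human-readable problems found in a COCO-style dict."""
    problems = []
    image_ids = {img['id'] for img in coco['images']}
    category_ids = {cat['id'] for cat in coco['categories']}
    for ann in coco['annotations']:
        if ann['image_id'] not in image_ids:
            problems.append('annotation {} references unknown image {}'.format(
                ann['id'], ann['image_id']))
        if ann['category_id'] not in category_ids:
            problems.append('annotation {} references unknown category {}'.format(
                ann['id'], ann['category_id']))
        # bbox is expected to be [x, y, w, h]
        _, _, w, h = ann['bbox']
        if w <= 0 or h <= 0:
            problems.append('annotation {} has an empty box {}'.format(
                ann['id'], ann['bbox']))
    return problems


def load_and_check(json_file):
    """Load a JSON annotation file from disk and check it."""
    with open(json_file) as f:
        return check_coco(json.load(f))
```

Calling load_and_check on train.json and test.json should return an empty list when the conversion went well.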

(2)xml2xml.py

#!/usr/bin/env python
# coding:utf-8

import os
import xml.etree.ElementTree as et

from lxml import etree
from lxml.etree import Element, SubElement

class_name = ["a", "b"]  # source classes to keep

ann_folder1 = ' '  # input folder with the original xml files
ann_folder2 = ' '  # output folder for the new xml files


for root, dirs, files in os.walk(ann_folder1, True):
    for file in files:
        print ('file = ', file)

        # open xml
        ann_path1 = os.path.join(ann_folder1, file)
        tree = et.parse(ann_path1)
        root1 = tree.getroot()

        # write xml
        root2 = Element('annotation')
        folder = SubElement(root2, 'folder')
        folder.text = 'voc'

        filename = SubElement(root2, 'filename')
        filename.text = file.replace('.xml', '.jpg')

        size = SubElement(root2, 'size')
        width = SubElement(size, 'width')
        width.text = root1.find('size').find('width').text
        height = SubElement(size, 'height')
        height.text = root1.find('size').find('height').text
        depth = SubElement(size, 'depth')
        depth.text = root1.find('size').find('depth').text

        obj_num = 0
        for obj in root1.findall('object'):
            if obj.find('name').text in class_name:
                obj_num += 1
                object = SubElement(root2, 'object')

                name = SubElement(object, 'name')
                # merge the source classes 'a' and 'b' into the single class 'c'
                if obj.find('name').text == 'a' or obj.find('name').text == 'b':
                    name.text = 'c'
                difficult = SubElement(object, 'difficult')
                difficult.text = obj.find('difficult').text

                bndbox = SubElement(object, 'bndbox')
                xmin = SubElement(bndbox, 'xmin')
                xmin.text = str(int(float(obj.find('bndbox').find('xmin').text)))
                ymin = SubElement(bndbox, 'ymin')
                ymin.text = str(int(float(obj.find('bndbox').find('ymin').text)))
                xmax = SubElement(bndbox, 'xmax')
                xmax.text = str(int(float(obj.find('bndbox').find('xmax').text)))
                ymax = SubElement(bndbox, 'ymax')
                ymax.text = str(int(float(obj.find('bndbox').find('ymax').text)))

                # if obj_name in class_filters:

        if obj_num > 0:
            ann_path2 = os.path.join(ann_folder2, file)
            doc = etree.ElementTree(root2)
            doc.write(ann_path2, pretty_print=True)
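
The core idea of xml2xml.py (keep only some classes and rename them to one target class) can also be sketched with the standard library alone. This is an illustrative variant, not the script above: merge_classes is a hypothetical helper, and it edits the parsed tree in place instead of building a new one with lxml.

```python
import xml.etree.ElementTree as ET


def merge_classes(xml_text, source_names, target_name):
    """Keep only objects whose <name> is in source_names and rename them.

    Returns the rewritten XML as a string, or None if no object survives.
    """
    root = ET.fromstring(xml_text)
    kept = 0
    # findall returns a list, so removing children while looping is safe
    for obj in root.findall('object'):
        name = obj.find('name')
        if name is not None and name.text in source_names:
            name.text = target_name
            kept += 1
        else:
            root.remove(obj)
    if kept == 0:
        return None
    return ET.tostring(root, encoding='unicode')


sample = ('<annotation>'
          '<object><name>a</name></object>'
          '<object><name>z</name></object>'
          '</annotation>')
print(merge_classes(sample, {'a', 'b'}, 'c'))
```

Returning None for files without any matching object mirrors the obj_num > 0 check in the script, which skips writing such files.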

Create a folder under data and store the files at the following paths: c/annotations/train.json, c/annotations/test.json, and c/images (this directory holds all train and test images).
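
A quick way to confirm the layout is to check the expected paths before starting a run. This is a small sketch under the assumption that the dataset folder is named c and lives under the data directory; expected_layout and missing_paths are hypothetical helper names.

```python
from pathlib import Path


def expected_layout(data_root, dataset_name='c'):
    """Paths that the dataset folder is expected to contain."""
    base = Path(data_root) / dataset_name
    return {
        'images': base / 'images',
        'train_json': base / 'annotations' / 'train.json',
        'test_json': base / 'annotations' / 'test.json',
    }


def missing_paths(data_root, dataset_name='c'):
    """Names of the expected entries that do not exist on disk."""
    layout = expected_layout(data_root, dataset_name)
    return [key for key, path in layout.items() if not path.exists()]


if __name__ == '__main__':
    # run from the src directory, where the data folder is ../data
    print(missing_paths('../data'))
```

An empty list means all three entries are in place.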

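4. Start training

(1) In src/lib/datasets/dataset/, copy coco.py to a new file c.py and modify it in the places listed below:

Line 13: change to class c(data.Dataset):, where c is your own class name

Line 14: set to your own number of classes (excluding the background class)

Line 15: set to your own input image size

Lines 16-19: optionally replace with your own mean and standard deviation, or keep the defaults. A script for computing the mean and standard deviation: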

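import os

import cv2
import numpy as np
from tqdm import tqdm


def main():
    dirs = r'path/to/images'  # change this to your own image directory
    img_file_names = os.listdir(dirs)
    m_list, s_list = [], []
    for img_filename in tqdm(img_file_names):
        img = cv2.imread(os.path.join(dirs, img_filename))
        if img is None:  # skip files OpenCV cannot read as images
            continue
        img = img / 255.0
        m, s = cv2.meanStdDev(img)  # per-channel mean and std of this image
        m_list.append(m.reshape((3,)))
        s_list.append(s.reshape((3,)))
    m_array = np.array(m_list)
    s_array = np.array(s_list)
    # average the per-image statistics over the whole dataset
    m = m_array.mean(axis=0, keepdims=True)
    s = s_array.mean(axis=0, keepdims=True)
    # OpenCV loads images as BGR, so [::-1] prints the values in RGB order
    print("mean = ", m[0][::-1])
    print("std = ", s[0][::-1])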


if __name__ == '__main__':
    main()
 

Line 22: change to super(c, self)

Line 23: self.data_dir = os.path.join(opt.data_dir, 'c')

Line 24: change '{}2017' to 'images'

Line 25: change 'test' to 'val', because only train.json and test.json were generated and there is no val.json

Line 28: change 'instances_extreme_{}2017.json' to 'test.json'

Line 37: change 'instances_{}2017.json' to 'train.json'

Line 39: change to your own classes, ['__background__', 'c']

Line 41: self._valid_ids = [1]
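
Taken together, the path-related edits mean that the val split reads test.json and every other split reads train.json, with all images in one folder. The function below is only a sketch of that logic as I read the edits above; dataset_paths is a hypothetical helper and does not exist in CenterNet.

```python
import os


def dataset_paths(data_dir, split):
    """Mirror the path logic of the edited c.py for a given split."""
    root = os.path.join(data_dir, 'c')
    # train and test images share one directory
    img_dir = os.path.join(root, 'images')
    # the 'val' split is served by test.json because no val.json exists
    annot_name = 'test.json' if split == 'val' else 'train.json'
    annot_path = os.path.join(root, 'annotations', annot_name)
    return img_dir, annot_path


for split in ('train', 'val'):
    print(split, dataset_paths('../data', split))
```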

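(2) Edit src/lib/datasets/dataset_factory.py

Line 22: add an entry for your own dataset, 'c': c (the class must also be imported at the top of this file, for example from .dataset.c import c)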
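
The new entry works because the --dataset value is used as a key into a dictionary of dataset classes. The sketch below illustrates that registry pattern with stand-in classes. COCO and c here are placeholders rather than the real classes, lookup is a hypothetical helper, and the real dataset_factory.py may be organized differently.

```python
class COCO:  # stand-in for the existing dataset class
    num_classes = 80


class c:  # stand-in for the new class defined in c.py
    num_classes = 1


dataset_factory = {
    'coco': COCO,
    'c': c,  # new entry: the --dataset value maps to the class
}


def lookup(name):
    """Resolve a --dataset value to its class, with a clear error message."""
    try:
        return dataset_factory[name]
    except KeyError:
        raise ValueError('unknown dataset {!r}; known: {}'.format(
            name, sorted(dataset_factory)))


print(lookup('c').num_classes)
```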

(3) Edit src/lib/opts.py

These two places must be changed:

15     self.parser.add_argument('--dataset', default='c',                                
16                              help='coco | kitti | coco_hp | pascal | c')

338       'ctdet': {'default_resolution': [320, 320], 'num_classes': 1,                        
339                 'mean': [0.408, 0.447, 0.470], 'std': [0.289, 0.274, 0.278], 
340                 'dataset': 'c'}

The changes below, to the learning rate and the backbone, are optional:

# model                                                                                  
-    self.parser.add_argument('--arch', default='dla_34',                                     
+    self.parser.add_argument('--arch', default='dlav0_34',   

# train

-    self.parser.add_argument('--lr', type=float, default=1.25e-4,                            
+    self.parser.add_argument('--lr', type=float, default=1e-3,   

-    self.parser.add_argument('--lr_step', type=str, default='90,120',                    
+    self.parser.add_argument('--lr_step', type=str, default='90,150', 

-    self.parser.add_argument('--num_epochs', type=int, default=140,                          
+    self.parser.add_argument('--num_epochs', type=int, default=200,
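
To see what these defaults imply, the sketch below prints the learning rate in effect at a few epochs. It assumes a step schedule in which the rate is divided by 10 once each epoch listed in --lr_step has finished, which is how I understand the training loop; lr_at_epoch is a hypothetical helper and not a CenterNet function.

```python
def lr_at_epoch(epoch, base_lr=1e-3, lr_step='90,150'):
    """Learning rate in effect during the given 1-based epoch."""
    steps = [int(s) for s in lr_step.split(',')]
    # one 10x drop for every step epoch that has already completed
    drops = sum(1 for s in steps if epoch > s)
    return base_lr * (0.1 ** drops)


for epoch in (1, 90, 91, 150, 151, 200):
    print(epoch, lr_at_epoch(epoch))
```

With --num_epochs 200 and --lr_step 90,150, training therefore spends the last 50 epochs at the lowest rate.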

(4) Edit src/lib/utils/debugger.py

48     elif num_classes == 1 or dataset == 'c':                                          
49       self.names = c_class_name

444 c_class_name = ["c"]

5. Training commands

1. train

Run from the src directory:

(1) Without loading weights: python main.py ctdet --exp_id c --batch_size 64 --lr 2e-3

(2) Loading pretrained weights: python main.py ctdet --exp_id c --batch_size 64 --lr 2e-3 --load_model ../models/ctdet_dla_2x.pth

(3) Multi-GPU training and resuming from a checkpoint: python main.py ctdet --exp_id c --batch_size 64 --lr 2e-3 --gpus 0,1,2 --load_model ../models/ctdet_dla_2x.pth --resume

2. test

Run from the src directory:

(1) Evaluate mAP: python test.py ctdet --exp_id c --not_prefetch_test --load_model ../exp/ctdet/c/model_best.pth

(2) Draw the detections to inspect the results: python demo.py ctdet --demo ../data/c/images --load_model ../exp/ctdet/c/model_best.pth

(3) Prediction with flip augmentation: python demo.py ctdet --demo ../data/c/images --load_model ../exp/ctdet/c/model_best.pth --flip_test

(4) Multi-scale prediction: python demo.py ctdet --demo ../data/c/images --load_model ../exp/ctdet/c/model_best.pth --test_scales 0.5,0.75,1.0,1.25,1.5

   Note: if multi-scale prediction raises an error, the cause may be that NMS has not been compiled. To compile it, go to path/to/CenterNet/src/lib/external and run: python setup.py build_ext --inplace
