Training CenterNet on a Custom Dataset and Parsing Its Post-Processing Output

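Code link:
https://github.com/xingyizhou/CenterNet/tree/master/readme
Setting up the environment for this project is painful. I followed the tutorial (py=3.6, torch=0.4.1), but got this error at runtime:
ImportError:/home/shiep/CenterNet/src/lib/models/networks/DCNs/_ext/dcn_v2/dcn_v2.so: undefined symbol: __cudaRegisterFatBinaryEnd
Workaround: the cause is that the driver version is too new, and reinstalling the driver is too much trouble, so I switched to another project:
https://github.com/shenyi0220/centernet-cp-cluster
This one runs fine in the environment I had already set up.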
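1. Preparing the dataset
The project expects annotations in COCO format, so the labels must be converted to JSON. The script below was adapted from an XML converter (hence the xml_* variable names), but it actually reads YOLO-format .txt labels: one "class cx cy w h" line per object, with coordinates normalized to [0, 1].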

# coding:utf-8
# Before running:
# pip install opencv-python numpy
# Put the images in the folder given by img_path and the YOLO .txt label
# files in the folder given by xml_dir.
#

import os
import glob
import json
import shutil

import cv2
import numpy as np
import xml.etree.ElementTree as ET

START_BOUNDING_BOX_ID = 1
save_path = "."

names = ["aa", "bb", "cc"]
txt_path = "labels/"
img_path="images/"

def get(root, name):
    return root.findall(name)


def get_and_check(root, name, length):
    vars = get(root, name)
    if len(vars) == 0:
        raise NotImplementedError('Can not find %s in %s.' % (name, root.tag))
    if length and len(vars) != length:
        raise NotImplementedError('The size of %s is supposed to be %d, but is %d.' % (name, length, len(vars)))
    if length == 1:
        vars = vars[0]
    return vars

def convert(xml_list, json_file):
    json_dict = {"images": [], "type": "instances", "annotations": [], "categories": []}
    categories = pre_define_categories.copy()
    bnd_id = START_BOUNDING_BOX_ID
    all_categories = {}
    for index, fi in enumerate(xml_list):
        print("Processing %s"%(fi))
        xml_f = fi
        # Find the image that matches this label file, trying common extensions
        stem = os.path.splitext(os.path.basename(xml_f))[0]
        image_id = 20190000001 + index
        img = None
        for ext in (".jpg", ".png", ".jpeg"):
            filename = stem + ext
            img = cv2.imread(os.path.join(img_path, filename))
            if img is not None:  # cv2.imread returns None when the file cannot be read
                break
        if img is None:
            raise FileNotFoundError("No image found for %s" % xml_f)
        height, width = img.shape[:2]
        image = {'file_name': filename, 'height': height, 'width': width, 'id': image_id}
        json_dict['images'].append(image)
        with open(xml_f) as label_file:
            lines = label_file.read().splitlines()
        for line in lines:
            line_all = line.split()
            if len(line_all) < 5:
                continue  # skip blank or malformed lines
            category = names[int(line_all[0])]
            if category in all_categories:
                all_categories[category] += 1
            else:
                all_categories[category] = 1

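            # Use the 1-based id from the category table so that it matches
            # the ids written to the "categories" list
            category_id = categories[category]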
            xmin = int(float(line_all[1])*width-float(line_all[3])*width/2)
            ymin = int(float(line_all[2])*height-float(line_all[4])*height/2)
            xmax = int(float(line_all[1])*width+float(line_all[3])*width/2)
            ymax = int(float(line_all[2])*height+float(line_all[4])*height/2)
            if xmax < xmin or ymax < ymin:
                continue
            o_width = abs(xmax - xmin)
            o_height = abs(ymax - ymin)
            ann = {'area': o_width * o_height, 'iscrowd': 0, 'image_id':
                image_id, 'bbox': [xmin, ymin, o_width, o_height],
                   'category_id': category_id, 'id': bnd_id, 'ignore': 0,
                   'segmentation': []}
            json_dict['annotations'].append(ann)
            bnd_id = bnd_id + 1

    for cate, cid in categories.items():
        cat = {'supercategory': 'ball', 'id': cid, 'name': cate}
        json_dict['categories'].append(cat)
    with open(json_file, 'w') as json_fp:
        json.dump(json_dict, json_fp)
    print("------------create {} done--------------".format(json_file))
    print("find {} categories: {} -->>> your pre_define_categories {}: {}".format(len(all_categories),
                                                                                  all_categories.keys(),
                                                                                  len(pre_define_categories),
                                                                                  pre_define_categories.keys()))
    print("category: id --> {}".format(categories))
    print(categories.keys())
    print(categories.values())


if __name__ == '__main__':
    # Define your own categories

    pre_define_categories = {}
    for i, cls in enumerate(names):
        pre_define_categories[cls] = i + 1
    # You can also set the category ids by hand: comment out the loop above and use the line below instead
    # pre_define_categories = {'a1': 1, 'a3': 2, 'a6': 3, 'a9': 4, "a10": 5}
    only_care_pre_define_categories = True  # or False

    # Output JSON files
    save_json_train = 'train_ship.json'
    save_json_val = 'val_ship.json'
    save_json_test = 'test_ship.json'

    # Directory containing the label files
    xml_dir = "/media/zhanglu/0bb0d537-0b35-45bf-94ec-9def4a6dd599/zhanglu/yolov5-fishi/data/ship/data_0519/labels"
    xml_list = glob.glob(xml_dir + "/*.txt")
    xml_list = np.sort(xml_list)

    # Shuffle the dataset
    np.random.seed(100)
    np.random.shuffle(xml_list)

    # Split the shuffled dataset by ratio
    train_ratio = 0.8
    val_ratio = 0.1
    train_num = int(len(xml_list) * train_ratio)
    val_num = int(len(xml_list) * val_ratio)
    xml_list_train = xml_list[:train_num]
    xml_list_val = xml_list[train_num: train_num + val_num]
    xml_list_test = xml_list[train_num + val_num:]

    # Convert the label files to COCO format, writing three JSON files (train/val/test)
    convert(xml_list_train, save_json_train)
    convert(xml_list_val, save_json_val)
    convert(xml_list_test, save_json_test)

    print("train number:", len(xml_list_train))
    print("val number:", len(xml_list_val))
    print("test number:", len(xml_list_test))
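
The coordinate arithmetic in the conversion script above can be isolated into two small helpers, which makes it easier to check by hand. This is only a sketch: the names parse_yolo_line and yolo_to_coco_bbox are not part of the original script, and it assumes each label line has the form "class cx cy w h" with values normalized to [0, 1].

```python
def parse_yolo_line(line):
    """Split one YOLO label line into (class_index, [cx, cy, w, h])."""
    parts = line.split()
    return int(parts[0]), [float(v) for v in parts[1:5]]


def yolo_to_coco_bbox(cx, cy, w, h, img_width, img_height):
    """Convert a normalized YOLO box (center x, center y, width, height)
    to a COCO box [xmin, ymin, width, height] in integer pixels."""
    xmin = int(cx * img_width - w * img_width / 2)
    ymin = int(cy * img_height - h * img_height / 2)
    xmax = int(cx * img_width + w * img_width / 2)
    ymax = int(cy * img_height + h * img_height / 2)
    return [xmin, ymin, xmax - xmin, ymax - ymin]


# A box centered in a 640x480 image that covers half of each dimension.
class_index, (cx, cy, w, h) = parse_yolo_line("1 0.5 0.5 0.5 0.5\n")
print(class_index, yolo_to_coco_bbox(cx, cy, w, h, 640, 480))
```

Note that COCO stores the box as top-left corner plus size, not as two corners, which is why the last two values are differences.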

2. Training on your own dataset
Reference:
https://blog.csdn.net/qq_41613251/article/details/114446107
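(1) In src/lib/datasets/dataset/, copy coco.py to a new file c.py and make the following changes (line numbers refer to the copied file):

Line 13: change to class c(data.Dataset):, where c is your own class name

Line 14: set your own number of classes (not counting the background class)

Line 15: set your own input image size

Lines 16-19: optionally replace the mean and std with your own values (for example, computed by a script over your images); the defaults also work

Line 22: change to super(c, self)

Line 23: self.data_dir = os.path.join(opt.data_dir, 'c')

Line 24: change '{}2017' to 'images'

Line 25: change 'test' to 'val', because only train.json and test.json were generated and there is no val.json

Line 28: change 'instances_extreme_{}2017.json' to 'test.json'

Line 37: change 'instances_{}2017.json' to 'train.json'

Line 39: change to your own classes, ['background', 'c']

Line 41: self._valid_ids = [1]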
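
For the mean and std mentioned in the item on lines 16-19, a minimal sketch of how such values could be computed is shown here. It is not the script the original post refers to. The function name channel_mean_std is made up for illustration, pixel values are assumed to be scaled to [0, 1] (consistent with defaults such as 0.408), and the channel order simply follows the arrays as loaded (BGR if they come from cv2.imread).

```python
import numpy as np


def channel_mean_std(images):
    """Per-channel mean and std over an iterable of HxWx3 uint8 images,
    with pixel values scaled to [0, 1]."""
    pixel_sum = np.zeros(3, dtype=np.float64)
    pixel_sq_sum = np.zeros(3, dtype=np.float64)
    pixel_count = 0
    for img in images:
        flat = img.astype(np.float64).reshape(-1, 3) / 255.0
        pixel_sum += flat.sum(axis=0)
        pixel_sq_sum += (flat ** 2).sum(axis=0)
        pixel_count += flat.shape[0]
    mean = pixel_sum / pixel_count
    # Var = E[x^2] - E[x]^2; clamp at 0 to guard against rounding error.
    variance = np.maximum(pixel_sq_sum / pixel_count - mean ** 2, 0.0)
    return mean, np.sqrt(variance)


# Two synthetic images stand in for a real dataset; with real data each
# image would come from something like cv2.imread(path).
demo_images = [
    np.full((4, 4, 3), 0, dtype=np.uint8),
    np.full((4, 4, 3), 255, dtype=np.uint8),
]
mean, std = channel_mean_std(demo_images)
print(mean, std)
```

Accumulating sums instead of stacking all images keeps memory use constant, which matters once the dataset no longer fits in RAM.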

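(2) Modify src/lib/datasets/dataset_factory.py

Line 22: add your own dataset, 'c': c (the class must also be imported at the top of the file, in the same way as the existing datasets, e.g. from .dataset.c import c)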

(3) Modify src/lib/opts.py

These two places must be changed:

15     self.parser.add_argument('--dataset', default='c',                                
16                              help='coco | kitti | coco_hp | pascal | c')

338       'ctdet': {'default_resolution': [320, 320], 'num_classes': 1,                        
339                 'mean': [0.408, 0.447, 0.470], 'std': [0.289, 0.274, 0.278], 
340                 'dataset': 'c'}

Further down in the same file, changing the learning rate and the backbone is optional.

(4) Modify src/lib/utils/debugger.py

48     elif num_classes == 1 or dataset == 'c':                                          
49       self.names = c_class_name

444 c_class_name = ["c"]

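When using NMS I hit an error: even after compiling under external, the module still could not be found.
from external.nms import soft_nms
ModuleNotFoundError: No module named 'external.nms'
Solution:
Rename nms.cpython-36m-x86_64-linux-gnu.so to nms.so
3. Difference between NMS and Soft-NMS:
Standard NMS can be called directly through the torchvision API: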

import torch
import torchvision
 
box =  torch.tensor([[2,3.1,7,5],[3,4,8,4.8],[4,4,5.6,7],[0.1,0,8,1]]) 
score = torch.tensor([0.5, 0.3, 0.2, 0.4])
 
output = torchvision.ops.nms(boxes=box, scores=score, iou_threshold=0.3)
print('IOU of bboxes:')
iou = torchvision.ops.box_iou(box,box)
print(iou)
print(output)
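
The heading contrasts NMS with Soft-NMS, but only hard NMS is shown above. Roughly, hard NMS discards every box whose IoU with a higher-scoring box exceeds the threshold, while Soft-NMS keeps such boxes and lowers their scores instead. The following NumPy sketch illustrates the Gaussian variant, where each score is multiplied by exp(-iou^2 / sigma). It is not the external.nms.soft_nms implementation used by the project. The function names and the sigma and score_threshold defaults are assumptions chosen for illustration.

```python
import numpy as np


def iou_one_to_many(box, boxes):
    """IoU between one box and an array of boxes, all as [x1, y1, x2, y2]."""
    x1 = np.maximum(box[0], boxes[:, 0])
    y1 = np.maximum(box[1], boxes[:, 1])
    x2 = np.minimum(box[2], boxes[:, 2])
    y2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)


def soft_nms_gaussian(boxes, scores, sigma=0.5, score_threshold=0.001):
    """Return (kept_indices, decayed_scores) in selection order."""
    boxes = np.asarray(boxes, dtype=np.float64)
    scores = np.asarray(scores, dtype=np.float64).copy()
    remaining = list(range(len(scores)))
    kept, kept_scores = [], []
    while remaining:
        # Pick the remaining box with the highest (possibly decayed) score.
        best = max(remaining, key=lambda i: scores[i])
        remaining.remove(best)
        if scores[best] < score_threshold:
            break  # every other remaining score is lower still
        kept.append(best)
        kept_scores.append(float(scores[best]))
        if remaining:
            rest = np.array(remaining)
            ious = iou_one_to_many(boxes[best], boxes[rest])
            # Decay instead of discarding: more overlap, stronger decay.
            scores[rest] *= np.exp(-(ious ** 2) / sigma)
    return kept, kept_scores


# Same boxes and scores as in the torchvision example.
demo_boxes = [[2, 3.1, 7, 5], [3, 4, 8, 4.8], [4, 4, 5.6, 7], [0.1, 0, 8, 1]]
demo_scores = [0.5, 0.3, 0.2, 0.4]
kept, kept_scores = soft_nms_gaussian(demo_boxes, demo_scores)
print(kept)
print(kept_scores)
```

With these inputs all four boxes survive with reduced scores, whereas hard NMS at a fixed IoU threshold removes overlapping boxes outright. That makes Soft-NMS more forgiving when objects genuinely overlap.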

4. Post-processing analysis
ret = detector.run(img_name)
results[img_id] = ret["results"]

ret["results"] is a dict with 9 entries; the values are arrays with the following shapes:
(9,5)
(18,5)
(5,5)
(30,5)
(16,5)
(4,5)
(10,5)
(4,5)
(4,5)
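
To make this structure concrete, here is a minimal sketch of how such a dict could be flattened and filtered by score. It assumes the keys are class ids and that each row of an (N, 5) array is [x1, y1, x2, y2, score]. That is consistent with the shapes listed above, but it is an assumption about this detector's output. The names filter_detections and fake_results are invented for illustration, and the data is made up.

```python
import numpy as np


def filter_detections(results, score_threshold=0.3):
    """Flatten a {class_id: (N, 5) array} dict into a list of
    (class_id, [x1, y1, x2, y2], score) tuples at or above the threshold."""
    detections = []
    for class_id, dets in results.items():
        for x1, y1, x2, y2, score in np.asarray(dets, dtype=np.float32):
            if score >= score_threshold:
                box = [float(x1), float(y1), float(x2), float(y2)]
                detections.append((class_id, box, float(score)))
    return detections


# Hypothetical stand-in for ret["results"]; real values come from detector.run().
fake_results = {
    1: np.array([[10, 20, 50, 80, 0.9], [12, 22, 48, 78, 0.1]], dtype=np.float32),
    2: np.zeros((0, 5), dtype=np.float32),  # a class with no detections
}
filtered = filter_detections(fake_results)
for class_id, box, score in filtered:
    print(class_id, box, round(score, 2))
```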

5. Error when using the project again after a long break

  File "centernet/centernet-cp-cluster-my/src/lib/models/model.py", line 12, in <module>
    from .networks.pose_dla_dcn import get_pose_net as get_dla_dcn
  File "centernet/centernet-cp-cluster-my/src/lib/models/networks/pose_dla_dcn.py", line 16, in <module>
    from .DCNv2.dcn_v2 import DCN
  File "centernet/centernet-cp-cluster-my/src/lib/models/networks/DCNv2/dcn_v2.py", line 12, in <module>
    import _ext as _backend
ModuleNotFoundError: No module named '_ext'

Solution: recompile

centernet/centernet-cp-cluster-my/src/lib/models/networks/DCNv2$ ./make.sh