使用CenterNet做目标检测

前言

项目地址:https://github.com/zzubqh/CenterNetDect-pytorch.git
之前写过一篇CenterNet源码结构解析,能看到官方的源码结构有点复杂,虽然已经剖析过了但是真到实际应用中查找细节的时候还是有点绕,虽然文章中也列出了另外一个简单实现,但是个人感觉结构依然不是太清晰,所以基于这个简单实现的源码进行了重构,主要修改的地方:

  1. 将数据增强部分封装成了一个类,基于imgaug包实现;
  2. 去掉了基于coco数据集的依赖,改成解析自己的数据集,数据集的格式变得很简单,格式如下:
    “文件路径 x1,y1,x2,y2,class_id x1,y1,x2,y2,class_id"
    F:\code\Data\DectDataset\case_006/coronal/245.jpg 762,112,850,195,3 756,292,843,366,3
    F:\code\Data\DectDataset\case_024/sagittal/69.jpg 736,139,817,212,3
    F:\code\Data\DectDataset\case_024/sagittal/67.jpg 737,136,822,213,3
    F:\code\Data\DectDataset\case_012/sagittal/179.jpg 834,95,912,161,3
    
    当然也可以直接写文件名而不是文件路径,在dataset类中再拼接成完整路径也可以
  3. 重写训练代码,改成了一个类,不再将训练函数和验证函数放在一起而是封装到了一个类中
  4. 重写了验证函数,使得验证后给出的结果更加直观明了
  5. 去掉了pose预测的代码,只专注于目标检测
  6. 修正了原项目中的一些错误,比如在hourglass.py源文件中有句代码“if self.training or ind == self.nstack - 1:”,但是整个类是没有self.training这个属性的,还有在get_hourglass中,源作者没有将类的个数即num_classess这个变量传入,所以整个工程默认80个类,即便在train.py中给了num_classess的值在构造hourglass实例的时候依然是80个类
  7. 基本上除了保留了核心的hourglass.py网络结构以外全部进行了重写,但依然很感谢https://github.com/zzzxxxttt/pytorch_simple_CenterNet_45作者的辛苦,表示敬意!

训练自己的数据

网络架构配置在config.py的cfg.arch =‘large_hourglass’ 中,名称只支持large_hourglass和small_hourglass,在large_hourglass模式下,显存至少8G以上。

数据准备

使用你自己熟悉的打标工具进行标注,如果使用的是labelme 4.5.6版本标注的,可以使用以下代码转成上面的数据格式

import os
import json

def create_image_dataset():
    """     
    输出label.txt格式: image_path x1,y1,x2,y2,class_id1 x1,y1,x2,y2,class_id2
    """
    annotation_file = r'bone_annotation.md'  # 输出的label文件
    input_root_dir = r'F:\code\Data\DectDataset'    
    jpeg_dir = [os.path.join(input_root_dir, 'img_dir']

    with open(annotation_file, 'w', encoding='utf-8') as wf:
        for child_dir in jpeg_dir:
            json_files = glob.glob(child_dir + '/*.json')
            print('load {0} json files success!'.format(child_dir))
            json_values = [json_pares(json_file) for json_file in json_files]
            labeled_img = {item['img_name']: {'bbox': item['bbox'], 'label': item['label']} for item in json_values}
            img_files = glob.glob(child_dir + '/*.jpg')
            for img_path in tqdm.tqdm(img_files):
                img_name = os.path.basename(img_path)
                bbox = []
                label = []
                key_name = img_name

                if key_name in labeled_img.keys():
                    bbox = labeled_img[key_name]['bbox']
                    label = labeled_img[key_name]['label']

                line_str = img_path
                for box_index, rect in enumerate(bbox):
                    item_str = ' {0},{1},{2},{3},{4}'.format(rect[0, 0], rect[0, 1], rect[1, 0], rect[1, 1], label[box_index])
                    line_str += item_str
                wf.write(line_str + '\r')


def json_pares(json_file):
    value_data = dict()
    with open(json_file, 'r', encoding='utf-8') as rf:
        json_data = json.load(rf)
        value_data.setdefault('img_name', json_data['imagePath'].replace('png', 'jpg'))
        value_data.setdefault('bbox', [])
        value_data.setdefault('label', [])
        for shape in json_data['shapes']:
            value_data['label'].append(shape['label'].lower())
            p1 = shape['points'][0]
            p2 = shape['points'][1]
            bbox = np.array([[p1[0], p1[1]], [p2[0], p2[1]]], dtype=np.int)
            value_data['bbox'].append(bbox)
    return value_data

注意:由于labelme中的label是文字描述的,而代码中需要使用label对应的序号,这个地方需要注意一下,可以在代码中直接改掉,也可以转换好后通过“查找-替换”来做,请自行搞定。

代码修改

  1. 在dataset.py文件中,找到self.max_objs,这里是一张图片中最多出现几个object,比如你的数据集中一张图片中最多要检测100个目标,则这里改成100,根据实际情况修改;将class_names = []修改成你的类别的名称
  2. 在config.py文件中,cfg.net_input_size是网络的输入图片尺寸,按[w, h]格式输入,我的是长方形的图片所以我填的是[512, 256];还有一个是cfg.num_classes修改成你数据集的类别总数,后面加1是因为背景类,所以num_class = 你的类别总数 + 1
  3. dataset.py文件中,有个get_image_id的函数,这里的作用是为了后面的验证时使用的id,只要保证不同的图片不一样的值就行了,我只是给出了一个例子,也可以直接使用UUID来替代

训练

根据实际情况修改好上面的内容后,开始训练

python train.py

训练过程:
在这里插入图片描述

测试

修改detect.py文件,分别将weights_file和img_path改成你模型的保存路径和测试图片的路径即可。
运行:

python detect.py

会自动将检测的结果显示出来
在这里插入图片描述

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值