基于MxNet实现目标检测-CenterNet【附部分源码及模型】


前言

  本文主要讲解基于mxnet深度学习框架实现目标检测,实现的模型为CenterNet

环境配置:
      python 3.8
      mxnet 1.7.0
      cuda 10.1


目标检测发展史及意义

  图像分类任务的实现可以让我们粗略的知道图像中包含了什么类型的物体,但并不知道物体在图像中哪一个位置,也不知道物体的具体信息,在一些具体的应用场景比如车牌识别、交通违章检测、人脸识别、运动捕捉,单纯的图像分类就不能完全满足我们的需求了。

  这时候,需要引入图像领域另一个重要任务:物体的检测与识别。在传统机器领域,一个典型的案例是利用HOG(Histogram of Gradient)特征来生成各种物体相应的“滤波器”,HOG滤波器能完整的记录物体的边缘和轮廓信息,利用这一滤波器过滤不同图片的不同位置,当输出响应值幅度超过一定阈值,就认为滤波器和图片中的物体匹配程度较高,从而完成了物体的检测。


一、数据集的准备

  首先我是用的是halcon数据集里边的药片,去了前边的100张做标注,后面的300张做测试,其中100张里边选择90张做训练集,10张做验证集。

1.标注工具的安装

pip install labelimg

进入cmd,输入labelimg,会出现如图的标注工具:
在这里插入图片描述

2.数据集的准备

首先我们先创建3个文件夹,如图:
在这里插入图片描述
DataImage:100张需要标注的图像
DataLabel:空文件夹,主要是存放标注文件,这个在labelimg中生成标注文件
test:存放剩下的300张图片,不需要标注
DataImage目录下和test目录的存放样子是这样的(以DataImage为例):
在这里插入图片描述

3.标注数据

  首先我们需要在labelimg中设置图像路径和标签存放路径,如图:
在这里插入图片描述
  然后先记住快捷键:w:开始编辑,a:上一张,d:下一张。这个工具只需要这三个快捷键即可完成工作。
  开始标注工作,首先按下键盘w,这个时候进入编辑框框的模式,然后在图像上绘制框框,输入标签(框框属于什么类别),即可完成物体1的标注,一张物体可以多个标注和多个类别,但是切记不可摸棱两可,比如这张图像对于某物体标注了,另一张图像如果出现同样的就需要标注,或者标签类别不可多个,比如这个图象A物体标注为A标签,下张图的A物体标出成了B标签,最终的效果如图:
在这里插入图片描述
最后标注完成会在DataLabel中看到标注文件,json格式:
在这里插入图片描述

4.解释xml文件的内容

在这里插入图片描述
xml标签文件如图,我们用到的就只有object对象,对其进行解析即可。


二、网络结构的介绍

论文地址:https://arxiv.org/pdf/1904.07850.pdf
网络结构:
在这里插入图片描述
  可以发现CenterNet网络比较简单,主要包括resnet50提取图片特征,然后是反卷积模块Deconv(三个反卷积)对特征图进行上采样,最后三个分支卷积网络用来预测heatmap, 目标的宽高和目标的中心点坐标。值得注意的是反卷积模块,其包括三个反卷积组,每个组都包括一个3*3的卷积和一个反卷积,每次反卷积都会将特征图尺寸放大一倍,有很多代码中会将反卷积前的3x3的卷积替换为DCNv2(Deformable ConvetNets V2)来提高模型拟合能力。

  现有的检测算法都需要对大量的候选位置进行分类,并且还需要后处理(主要是NMS),比如二阶段的Faster RCNN、一阶段的YOLOv3、anchor-free的FCOS、CornerNet等算法,都是如此
  CenterNet是anchor-free的,正样本的分配极其简单,一个目标只唯一对应heatmap上的一个peak,是无需NMS的。


三、代码实现

0.工程目录结构如下

在这里插入图片描述
core:损失计算及一些核心计算的文件都存放在此文件夹
data:数据加载的相关函数及类
net:包含主干网络结构及标准的centernet结构
utils:数据预处理的相关文件
Ctu_CenterNet.py:centernet的训练类和测试类,是整个AI的主入口


1.导入库

import os, time, warnings,json,cv2,colorsys,sys,copy
sys.path.append('.')
import numpy as np
import mxnet as mx
from mxnet import nd
from mxnet import gluon
from mxnet import autograd
from nets.centernet import get_center
from data.data_loader import VOCDetection, VOC07MApMetric
from data.batchify_fn import Tuple, Stack, Pad
from data.data_transform import CenterNetDefaultTrainTransform,CenterNetDefaultValTransform
from core.lr_scheduler import LRScheduler,LRSequential
from core.loss import HeatmapFocalLoss, MaskedL1Loss
from utils.image import *
from PIL import Image,ImageDraw,ImageFont

2.配置GPU/CPU环境

self.ctx = [mx.gpu(int(i)) for i in USEGPU.split(',') if i.strip()]
self.ctx = self.ctx if self.ctx else [mx.cpu()]

3.数据加载器

这里输入的是迭代器,后面都会利用它构建训练的迭代器

class VOCDetection(dataset.Dataset):
    def CreateDataList(self,IMGDir,XMLDir):
        ImgList = os.listdir(IMGDir)
        XmlList = os.listdir(XMLDir)
        classes = []
        dataList=[]
        for each_jpg in ImgList:
            each_xml = each_jpg.split('.')[0] + '.xml'
            if each_xml in XmlList:
                dataList.append([os.path.join(IMGDir,each_jpg),os.path.join(XMLDir,each_xml)])
                with open(os.path.join(XMLDir,each_xml), "r", encoding="utf-8") as in_file:
                    tree = ET.parse(in_file)
                    root = tree.getroot()
                    for obj in root.iter('object'):
                        cls = obj.find('name').text
                        if cls not in classes:
                            classes.append(cls)
        return dataList,classes

    def __init__(self, ImageDir, XMLDir,transform=None):
        self.datalist,self.classes_names = self.CreateDataList(ImageDir,XMLDir)
        self._transform = transform
        self.index_map = dict(zip(self.classes_names, range(len(self.classes_names))))
        # self._label_cache = self._preload_labels()

    @property
    def classes(self):
        return self.classes_names

    def __len__(self):
        return len(self.datalist)

    def __getitem__(self, idx):
        img_path = self.datalist[idx][0]
        # label = self._label_cache[idx] if self._label_cache else self._load_label(idx)
        label = self._load_label(idx)
        img = mx.image.imread(img_path, 1)
        if self._transform is not None:
            return self._transform(img, label)
        return img, label.copy()

    def _preload_labels(self):
        return [self._load_label(idx) for idx in range(len(self))]

    def _load_label(self, idx):
        anno_path = self.datalist[idx][1]
        root = ET.parse(anno_path).getroot()
        size = root.find('size')
        width = float(size.find('width').text)
        height = float(size.find('height').text)
        label = []
        for obj in root.iter('object'):
            try:
                difficult = int(obj.find('difficult').text)
            except ValueError:
                difficult = 0
            cls_name = obj.find('name').text.strip().lower()
            if cls_name not in self.classes:
                continue
            cls_id = self.index_map[cls_name]
            xml_box = obj.find('bndbox')
            xmin = (float(xml_box.find('xmin').text) - 1)
            ymin = (float(xml_box.find('ymin').text) - 1)
            xmax = (float(xml_box.find('xmax').text) - 1)
            ymax = (float(xml_box.find('ymax').text) - 1)
            try:
                self._validate_label(xmin, ymin, xmax, ymax, width, height)
                label.append([xmin, ymin, xmax, ymax, cls_id, difficult])
            except AssertionError as e:
                pass
        return np.array(label)

    def _validate_label(self, xmin, ymin, xmax, ymax, width, height):
        assert 0 <= xmin < width, "xmin must in [0, {}), given {}".format(width, xmin)
        assert 0 <= ymin < height, "ymin must in [0, {}), given {}".format(height, ymin)
        assert xmin < xmax <= width, "xmax must in (xmin, {}], given {}".format(width, xmax)
        assert ymin < ymax <= height, "ymax must in (ymin, {}], given {}".format(height, ymax)


4.模型构建

本项目使用resnet、mobilenet、dla做主干网络结构,这里只需要传入backbone即可

self.model = get_center(backbone,self.classes_names,ctx=self.ctx[0],norm_layer=gluon.nn.BatchNorm,use_dcnv2=False)

class CenterNet(nn.HybridBlock):
    def __init__(self, base_network, heads, classes, head_conv_channel=0, scale=4.0, topk=100, flip_test=False, nms_thresh=0, nms_topk=400, post_nms=100, **kwargs):
        if 'norm_layer' in kwargs:
            kwargs.pop('norm_layer')
        if 'norm_kwargs' in kwargs:
            kwargs.pop('norm_kwargs')
        super(CenterNet, self).__init__(**kwargs)
        assert isinstance(heads, OrderedDict), "Expecting heads to be a OrderedDict per head, given {}".format(type(heads))
        self.classes = classes
        self.topk = topk
        self.nms_thresh = nms_thresh
        self.nms_topk = nms_topk
        post_nms = min(post_nms, topk)
        self.post_nms = post_nms
        self.scale = scale
        self.flip_test = flip_test
        self._head_setups = heads
        self._head_conv_channel = head_conv_channel
        with self.name_scope():
            self.base_network = base_network
            self.heatmap_nms = nn.MaxPool2D(pool_size=3, strides=1, padding=1)
            self.decoder = CenterNetDecoder(topk=topk, scale=scale)
            self.heads = nn.HybridSequential('heads')
            for name, values in heads.items():
                head = nn.HybridSequential(name)
                num_output = values['num_output']
                bias = values.get('bias', 0.0)
                weight_initializer = mx.init.Normal(0.001) if bias == 0 else mx.init.Xavier()
                if head_conv_channel > 0:
                    head.add(nn.Conv2D(head_conv_channel, kernel_size=3, padding=1, use_bias=True, weight_initializer=weight_initializer, bias_initializer='zeros'))
                    head.add(nn.Activation('relu'))
                head.add(nn.Conv2D(num_output, kernel_size=1, strides=1, padding=0, use_bias=True, weight_initializer=weight_initializer, bias_initializer=mx.init.Constant(bias)))
                self.heads.add(head)

    @property
    def num_classes(self):
        return len(self.classes)

    def set_nms(self, nms_thresh=0, nms_topk=400, post_nms=100):
        self._clear_cached_op()
        self.nms_thresh = nms_thresh
        self.nms_topk = nms_topk
        post_nms = min(post_nms, self.nms_topk)
        self.post_nms = post_nms

    def reset_class(self, classes, reuse_weights=None):
        self._clear_cached_op()
        old_classes = self.classes
        self.classes = classes
        
        if isinstance(reuse_weights, (dict, list)):
            if isinstance(reuse_weights, dict):
                new_keys = []
                new_vals = []
                for k, v in reuse_weights.items():
                    if isinstance(v, str):
                        try:
                            new_vals.append(old_classes.index(v))
                        except ValueError:
                            raise ValueError("{} not found in old class names {}".format(v, old_classes))
                    else:
                        if v < 0 or v >= len(old_classes):
                            raise ValueError("Index {} out of bounds for old class names".format(v))
                        new_vals.append(v)
                    if isinstance(k, str):
                        try:
                            new_keys.append(self.classes.index(k))
                        except ValueError:
                            raise ValueError("{} not found in new class names {}".format(k, self.classes))
                    else:
                        if k < 0 or k >= len(self.classes):
                            raise ValueError("Index {} out of bounds for new class names".format(k))
                        new_keys.append(k)
                reuse_weights = dict(zip(new_keys, new_vals))
            else:
                new_map = {}
                for x in reuse_weights:
                    try:
                        new_idx = self.classes.index(x)
                        old_idx = old_classes.index(x)
                        new_map[new_idx] = old_idx
                    except ValueError:
                        print("{} not found in old: {} or new class names: {}".format(x, old_classes, self.classes))
                reuse_weights = new_map
        
        with self.name_scope():
            hm_head = nn.HybridSequential('heatmap')
            orig_head = self.heads
            orig_hm = self.heads[0]
            for i in range(len(orig_hm) - 1):
                hm_head.add(orig_hm[i])
            num_output = len(classes)
            bias = self._head_setups['heatmap'].get('bias', 0.0)
            weight_initializer = mx.init.Normal(0.001) if bias == 0 else mx.init.Xavier()

            in_channels = list(orig_hm[0].params.values())[0].shape[1]
            hm_head.add(nn.Conv2D(num_output, kernel_size=1, strides=1, padding=0, use_bias=True, weight_initializer=weight_initializer, bias_initializer=mx.init.Constant(bias), in_channels=in_channels))
            with warnings.catch_warnings(record=True) as _:
                warnings.simplefilter("always")
                ctx = list(orig_hm[0].params.values())[0].list_ctx()
                hm_head.initialize(ctx=ctx)
            if reuse_weights:
                assert isinstance(reuse_weights, dict)
                for old_params, new_params in zip(orig_hm[2].params.values(),
                                                  hm_head[2].params.values()):
                    old_data = old_params.data()
                    new_data = new_params.data()

                    for k, v in reuse_weights.items():
                        if k > len(self.classes) or v > len(old_classes):
                            warnings.warn("reuse mapping {}/{} -> {}/{} out of range".format(
                                k, self.classes, v, old_classes))
                            continue
                        new_data[k::len(self.classes)] = old_data[v::len(old_classes)]

                    new_params.set_data(new_data)
            old_heads = self.heads
            self.heads = nn.HybridSequential('heads')
            self.heads.add(hm_head)
            self.heads.add(orig_head[1])
            self.heads.add(orig_head[2])

    def hybrid_forward(self, F, x):
        y = self.base_network(x)
        out = [head(y) for head in self.heads]
        out[0] = F.sigmoid(out[0])
        if autograd.is_training():
            out[0] = F.clip(out[0], 1e-4, 1 - 1e-4)
            return tuple(out)
        if self.flip_test:
            y_flip = self.base_network(x.flip(axis=3))
            out_flip = [head(y_flip) for head in self.heads]
            out_flip[0] = F.sigmoid(out_flip[0])
            out[0] = (out[0] + out_flip[0].flip(axis=3)) * 0.5
            out[1] = (out[1] + out_flip[1].flip(axis=3)) * 0.5
        heatmap = out[0]
        keep = F.broadcast_equal(self.heatmap_nms(heatmap), heatmap)
        results = self.decoder(keep * heatmap, out[1], out[2])
        return results



5.模型训练

1.学习率设置

lr_steps = sorted([int(ls) for ls in lr_decay_epoch.split(',') if ls.strip()])
lr_decay_epoch = [e for e in lr_steps]

 lr_scheduler = LRSequential([
     LRScheduler('linear', base_lr=0, target_lr=learning_rate,
                 nepochs=0, iters_per_epoch=self.num_samples // self.batch_size),
     LRScheduler(lr_mode, base_lr=learning_rate,
                 nepochs=TrainNum,
                 iters_per_epoch=self.num_samples // self.batch_size,
                 step_epoch=lr_decay_epoch,
                 step_factor=lr_decay, power=2),
 ])

2.优化器设置

if optim == 1:
    trainer = gluon.Trainer(self.model.collect_params(), 'sgd', {'learning_rate': learning_rate, 'wd': 0.0005, 'momentum': 0.9, 'lr_scheduler': lr_scheduler})
elif optim == 2:
    trainer = gluon.Trainer(self.model.collect_params(), 'adagrad', {'learning_rate': learning_rate, 'lr_scheduler': lr_scheduler})
else:
    trainer = gluon.Trainer(self.model.collect_params(), 'adam', {'learning_rate': learning_rate, 'lr_scheduler': lr_scheduler})

3.损失设置

heatmap_loss = HeatmapFocalLoss(from_logits=True)
wh_loss = MaskedL1Loss(weight=0.1)
center_reg_loss = MaskedL1Loss(weight=0.1)
heatmap_loss_metric = mx.metric.Loss('HeatmapFocal')
wh_metric = mx.metric.Loss('WHL1')
center_reg_metric = mx.metric.Loss('CenterRegL1')

4.循环训练

for i, batch in enumerate(self.train_loader):
    split_data = [gluon.utils.split_and_load(batch[ind], ctx_list=self.ctx, batch_axis=0) for ind in range(6)]
    batch_size = self.batch_size
    with autograd.record():
        sum_losses = []
        heatmap_losses = []
        wh_losses = []
        center_reg_losses = []
        wh_preds = []
        center_reg_preds = []
        for x, heatmap_target, wh_target, wh_mask, center_reg_target, center_reg_mask in zip(*split_data):
            heatmap_pred, wh_pred, center_reg_pred = self.model(x)
            wh_preds.append(wh_pred)
            center_reg_preds.append(center_reg_pred)
            wh_losses.append(wh_loss(wh_pred, wh_target, wh_mask))
            center_reg_losses.append(center_reg_loss(center_reg_pred, center_reg_target, center_reg_mask))
            heatmap_losses.append(heatmap_loss(heatmap_pred, heatmap_target))
            curr_loss = heatmap_losses[-1]+ wh_losses[-1] + center_reg_losses[-1]
            sum_losses.append(curr_loss)
        autograd.backward(sum_losses)
    trainer.step(len(sum_losses))

    heatmap_loss_metric.update(0, heatmap_losses)
    wh_metric.update(0, wh_losses)
    center_reg_metric.update(0, center_reg_losses)

    name2, loss2 = wh_metric.get()
    name3, loss3 = center_reg_metric.get()
    name4, loss4 = heatmap_loss_metric.get()
    print('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, LR={}, {}={:.3f}, {}={:.3f}, {}={:.3f}'.format(epoch, i, batch_size/(time.time()-btic), trainer.learning_rate, name2, loss2, name3, loss3, name4, loss4))
    btic = time.time()

6.模型预测

def predict(self, image, confidence=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    start_time = time.time()
    origin_img = copy.deepcopy(image)
    base_imageSize = origin_img.shape
    image = self.resize_image(image,(self.image_size,self.image_size))
    # print(resize_imageSize,base_imageSize)   # (512, 512, 3) (780, 1248, 3)
    img = nd.array(image)

    # img = resize_short_within(img, self.image_size, max_size)
    img = mx.nd.image.to_tensor(img)
    img = mx.nd.image.normalize(img, mean=mean, std=std)
    x = img.expand_dims(0)

    x = x.as_in_context(self.ctx[0])
    labels, scores, bboxes = [xx[0].asnumpy() for xx in self.model(x)]

    origin_img_pillow = self.cv2_pillow(origin_img)
    font = ImageFont.truetype(font='./model_data/simhei.ttf', size=np.floor(3e-2 * np.shape(origin_img_pillow)[1] + 0.5).astype('int32'))
    thickness = max((np.shape(origin_img_pillow)[0] + np.shape(origin_img_pillow)[1]) // self.image_size, 1)

    imgbox = []
    for i, bbox in enumerate(bboxes):
        if (scores is not None and scores.flat[i] < confidence) or labels is not None and labels.flat[i] < 0:
            continue
        cls_id = int(labels.flat[i]) if labels is not None else -1

        xmin, ymin, xmax, ymax = [int(x) for x in bbox]
        xmin, ymin, xmax, ymax = xmin/self.image_size, ymin/self.image_size, xmax/self.image_size, ymax/self.image_size
        box_xy, box_wh = np.array([(xmin+xmax)/2,(ymin+ymax)/2]).astype('float32'), np.array([xmax-xmin,ymax-ymin]).astype('float32')
        image_shape = np.array((base_imageSize[0],base_imageSize[1]))
        input_shape = np.array((self.image_size,self.image_size))
        result = self.correct_boxes(box_xy, box_wh, input_shape, image_shape,True)
        ymin, xmin, ymax, xmax = result

        xmin, ymin, xmax, ymax = int(xmin), int(ymin), int(xmax), int(ymax)
        class_name = self.classes_names[cls_id]
        score = '{:d}%'.format(int(scores.flat[i] * 100)) if scores is not None else ''
        imgbox.append([(xmin, ymin, xmax, ymax), cls_id, self.classes_names[cls_id], score])
        top, left, bottom, right = ymin, xmin, ymax, xmax


        # cv2.rectangle(origin_img, (xmin, ymin), (xmax, ymax), self.colors[cls_id], 2)
        # if class_name or score:
        #     y = ymin - 15 if ymin - 15 > 15 else ymin + 15
        #     cv2.putText(origin_img, '{:s} {:s}'.format(class_name, score),
        #                 (xmin, y), cv2.FONT_HERSHEY_SIMPLEX, min(1.0 / 2, 2),
        #                 self.colors[cls_id], min(int(1.0), 5), lineType=cv2.LINE_AA)
        label = '{}-{}'.format(class_name, score)
        draw = ImageDraw.Draw(origin_img_pillow)
        label_size = draw.textsize(label, font)
        label = label.encode('utf-8')

        if top - label_size[1] >= 0:
            text_origin = np.array([left, top - label_size[1]])
        else:
            text_origin = np.array([left, top + 1])

        for i in range(thickness):
            draw.rectangle([left + i, top + i, right - i, bottom - i], outline=self.colors[cls_id])
        draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)], fill=self.colors[cls_id])
        draw.text(text_origin, str(label,'UTF-8'), fill=(0, 0, 0), font=font)
        del draw

    result_value = {
        "image_result": self.pillow_cv2(origin_img_pillow),
        "bbox": imgbox,
        "time": (time.time() - start_time) * 1000
    }

    return result_value

四、算法主入口

if __name__ == '__main__':
    ctu = Ctu_CenterNet(USEGPU='0', image_size=512)
    ctu.InitModel(DataDir=r'D:/Ctu/Ctu_Project_DL/DataSet/DataSet_Detection_YaoPian',batch_size=1,Pre_Model = None, num_workers = 0,backbone='resnet18')
    ctu.train(TrainNum=500, learning_rate=0.001, lr_decay_epoch='50,100,150,200', lr_decay=0.9, ModelPath='./Model',optim=0)

    # ctu = Ctu_CenterNet(USEGPU='0')
    # ctu.LoadModel(r'D:/Ctu/Ctu_Project_DL/Ctu_Mxnet_DL/Detection/Ctu_Detection/Ctu_CenterNet/Model_centerNet_resnet50')
    # cv2.namedWindow("result", 0)
    # cv2.resizeWindow("result", 640, 480)
    # index = 0
    # for root, dirs, files in os.walk(r'D:/Ctu/Ctu_Project_DL/DataSet/DataSet_Detection_YaoPian/test'):
    #     for f in files:
    #         img_cv = ctu.read_image(os.path.join(root, f))
    #         if img_cv is None:
    #             continue
    #         res = ctu.predict(img_cv, 0.5)
    #         for each in res['bbox']:
    #             print(each)
    #         print("耗时:" + str(res['time']) + ' ms')
    #         # cv2.imwrite(str(index + 1)+'.bmp',res['image_result'])
    #         cv2.imshow("result", res['image_result'])
    #         cv2.waitKey()
    #         # index +=1

五、训练效果展示

备注:项目模型的本人没有保存因此会后续提供训练效果
在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

爱学习的广东仔

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值