TensorRT+CUDA加速优化版CenterNet旋转目标以及水平目标框的检测

51 篇文章 7 订阅
3 篇文章 2 订阅

前言

由于工作项目所需,一直用centerNet做旋转目标检测,在实际产品或者工业应用上落地此检测算法,那么在足够的算力下, 更好优选的方式还是需要c/c++来部署实现。
那么CenterNet也带来一个问题,那就是部署不太容易,主要是两个方面:

  • 主流实现大多不好支持onnx导出;
  • 后处理与传统的检测算法不太一样,比如nms,CenterNet用的实际上是一个3x3的maxpooling。
  • 此处还涉及到一点就是,原版的centerNet并可以检测旋转目标,所以此处就涉及到一个角度问题。

首选是将centerNet导出为onnx模型,之后再使用TensorRT对CenterNet进行加速。
本篇文章不仅仅只是告诉大家,CenterNet可以用TensorRT+cuda加速,同时也想告诉大家整个过程大概是怎么样的,最终我也会将我的整个工程代码链接放出来。(本文需要支持导出onnx的CenterNet版本),cuda编程入门篇见我的另一个博文哦。[借一个栗子来讲讲解基于C的cuda编程)

流程:

  1. 使用pytorch训练模型,生成.pth文件;
  2. 将*.pth转换成onnx模型;
  3. 在tensorrt中加载onnx模型,并转换成trtde object;
  4. 在trt中使用第三步转换成得object进行推理。

一、关于CenterNet后处理的思考

最终用TensorRT重构centernet,由于python代码太过于复杂,单前处理就花了很长时间。现在前处理基本上没问题,网络已经可以在tensorrt上进行forward,但是真正复杂的地方才刚开始。(注意下面的180128*128中的80代表类别,128也是相对于512的输入,down=4的来说的,如果你的输入size和类别不同,需要根据自己的实际情况来调整)
具体来说,centernet的后处理分为两个部分

  • 对 1x80x128x128这个feature进行3x3的maxpool,其实就等于是我们传统的进行nms一样;
  • 从1x80x128x128里面根据热力图峰值点,找到我们需要的前100个目标,这里面有一个很重要的操作叫做TopK,这个topk就是将张量的最大的几个值返回,同时返回对应的index。然后根据index找到对应的中心点和偏移,对应的最终框也就找到了。
    比较棘手的问题也是两个:
  • 最后的这一步maxpool和topk为什么不能放到模型里面去?而是需要后处理来算?这机加大了复杂度,同时有很浪费时间。
    思考如下:这里的maxpool其实事可以放到模型里面去的,但是那样的化, 你网络出来的东西就比较多了,要出来1x80x100 同时还要对应的index。这些张量的操作放到C++里面怎么都不太好做。

主要是onnx到TensorRT的两个问题

  • 后处理本质上是支持把这些操作同时trace到onnx模型的,它甚至不需要单独的写Plugin,因为它只需要两个onnx的op,即maxpool和topk;
  • 但问题是这两个op,在TensorRT中并没有相应的支持。

最终我们选择onnx的导出方式是,只导出网络部分,对于后处理部分,我们再单独的写layer来支持它。不过在最开始这一步是很难完成的,原因主要是:

  • 后处理有点复杂,虽然从原理上来讲,只是一个对
    1x80x128x128的特征热力图的maxpool,但是自己写代码就有点繁杂了,最关键的是速度可能会很慢;
    直接从网络拿到输出,再自己写代码后处理会涉及到一个topK的问题,即找到张量里面特定维度的最大值的index,这个要自己手鲁代码估计会非常非常的慢。

好在最近有朋友放出一个CUDA Kernel很好的解决了这个问题,可以说这个kernel就是本文实现CenterNet TensorRT加速的核心。这个kernel的核心逻辑,其实就是解决了上面提到的后处理的问题。这部分CUDA代码贴出来其实也十分的简单:
这部分CUDA代码贴出来其实也十分的简单:

__global__ void Cdetforward_kernel(const float *hm, const float *reg,const float *wh ,
        float *output,const int w,const int h,const int classes,const int kernerl_size,const float visthresh  ) {
    int idx = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
    if (idx >= w*h) return;
    int padding = kernerl_size/2;
    int offset = - padding /2;
    int stride = w*h;
    int grid_x = idx % w ;
    int grid_y = idx / w ;
    int cls,l,m;
    float c_x,c_y;
    for (cls = 0; cls < classes; ++cls )
    {
        int objIndex = stride * cls + idx;
        float objProb = hm[objIndex];
        float max=-1;
        int max_index =0;
        for(l=0 ;l < kernerl_size ; ++l)
            for(m=0 ; m < kernerl_size ; ++m){
                int cur_x = offset + l + grid_x;
                int cur_y = offset + m + grid_y;
                int cur_index = cur_y * w + cur_x + stride*cls;
                int valid = (cur_x>=0 && cur_x < w && cur_y >=0 && cur_y <h );
                float val = (valid !=0 ) ? Logist(hm[cur_index]): -1;
                max_index = (val > max) ? cur_index : max_index;
                max = (val > max ) ?  val: max ;
            }
        objProb = Logist(objProb);
        if((max_index == objIndex) && (objProb > visthresh)){int resCount = (int)atomicAdd(output,1);
            //printf("%d",resCount);
            char* data = (char * )output + sizeof(float) + resCount*sizeof(Detection);
            Detection* det =  (Detection*)(data);
            c_x = grid_x + reg[idx] ; c_y  = grid_y + reg[idx+stride];
            det->bbox.x1 = (c_x - wh[idx]/2)*4;
            det->bbox.y1 = (c_y - wh[idx+stride]/2)*4 ;
            det->bbox.x2 = (c_x + wh[idx]/2)*4;
            det->bbox.y2 = (c_y + wh[idx+stride]/2)*4;
            det->classId = cls;
            det->prob = objProb;
        }
    }
}

需要注意的是,这里面的kernelsize就是maxpool的边长,通过设定不同的值你可以控制整个网络的粒度。最后我们通过一个threshold在这部分里面踢掉了一些置信度比较低的heappoint。
最后将结果保存为一个Detection 的数据结构中。
就这么简单,这边是CenterNet TensorRT加速的核心。最终我们也实现了通过这个后处理可以完美的得到和python一样的结果,同时拥有更快的推理速度。

未完待续

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值