非极大值抑制(nonMaximumSuppression)

博客地址:http://blog.csdn.net/qq_14845119/article/details/52064928


理论基础

         说实话,讲理论基础实在不是我的强项,但是还是得硬着头皮来讲,希望我的讲解不至于晦涩难懂。

         非极大值抑制,简称为NMS算法。是一种获取局部最大值的有效方法。在3领域中,假设一个行向量的长度为w,从左向右,由第一个到第w个和其3领域中的数值进行比对。

如果某个i大于i+1并且小于i-1,则其为一个绝不最大值,同时也就意味着i+1不是一个局部最大值,所以将i移动2个步长,从i+2开始继续向后进行比较判断。如果某个i不满足上述条件,则将i+1,继续对i+1进行比对。当比对到最后一个w时,直接将w设置为局部最大值。算法流程如下图所示。

应用范围

         非极大值抑制NMS在目标检测,定位等领域是一种被广泛使用的方法。对于目标具体位置定位过程,不管是使用sw(sliding Window)还是ss(selective search)方法,都会产生好多的候选区域。实际看到的情形就是好多区域的交叉重叠,难以满足实际的应用。如下图所示。

针对该问题有3种传统的解决思路。

         第一种,选取好多矩形框的交集,即公共区域作为最后的目标区域。

         第二种,选取好多矩形框的并集,即所有矩形框的最小外截矩作为目标区域。当然这里也不是只要相交就直接取并集,需要相交的框满足交集占最小框的面积达到一定比例(也就是阈值)才合并。

         第三种,也就是本文的NMS,简单的说,对于有相交的就选取其中置信度最高的一个作为最后结果,对于没相交的就直接保留下来,作为最后结果。

         总体来说,3种处理思路都各有千秋,不能一概评论哪种好坏。各种顶会论文也会选择不同的处理方法。

程序实现

         本文提供了nonMaximumSuppression的C语言,c++,M语言,三个版本。

         其中,c语言版本为OpenCV的源码这里摘出并进行相关的注释。sort为排序函数,这里是最基本的选择排序算法的实现。nonMaximumSuppression为具体非极大值抑制的实现。

[cpp]  view plain  copy
  1. static void sort(int n, const float* x, int* indices)  
  2. {  
  3. // 排序函数,排序后进行交换的是indices中的数据  
  4. // n:排序总数// x:带排序数// indices:初始为0~n-1数目   
  5.   
  6.     int i, j;  
  7.     for (i = 0; i < n; i++)  
  8.         for (j = i + 1; j < n; j++)  
  9.         {  
  10.             if (x[indices[j]] > x[indices[i]])  
  11.             {  
  12.                 //float x_tmp = x[i];  
  13.                 int index_tmp = indices[i];  
  14.                 //x[i] = x[j];  
  15.                 indices[i] = indices[j];  
  16.                 //x[j] = x_tmp;  
  17.                 indices[j] = index_tmp;  
  18.             }  
  19.         }  
  20. }  


[cpp]  view plain  copy
  1. int nonMaximumSuppression(int numBoxes, const CvPoint *points,  
  2.                           const CvPoint *oppositePoints, const float *score,  
  3.                           float overlapThreshold,  
  4.                           int *numBoxesOut, CvPoint **pointsOut,  
  5.                           CvPoint **oppositePointsOut, float **scoreOut)  
  6. {  
  7.   
  8. // numBoxes:窗口数目// points:窗口左上角坐标点// oppositePoints:窗口右下角坐标点  
  9. // score:窗口得分// overlapThreshold:重叠阈值控制// numBoxesOut:输出窗口数目  
  10. // pointsOut:输出窗口左上角坐标点// oppositePoints:输出窗口右下角坐标点  
  11. // scoreOut:输出窗口得分  
  12.     int i, j, index;  
  13.     float* box_area = (float*)malloc(numBoxes * sizeof(float));    // 定义窗口面积变量并分配空间   
  14.     int* indices = (int*)malloc(numBoxes * sizeof(int));          // 定义窗口索引并分配空间   
  15.     int* is_suppressed = (int*)malloc(numBoxes * sizeof(int));    // 定义是否抑制表标志并分配空间   
  16.     // 初始化indices、is_supperssed、box_area信息   
  17.     for (i = 0; i < numBoxes; i++)  
  18.     {  
  19.         indices[i] = i;  
  20.         is_suppressed[i] = 0;  
  21.         box_area[i] = (float)( (oppositePoints[i].x - points[i].x + 1) *  
  22.                                 (oppositePoints[i].y - points[i].y + 1));  
  23.     }  
  24.     // 对输入窗口按照分数比值进行排序,排序后的编号放在indices中   
  25.     sort(numBoxes, score, indices);  
  26.     for (i = 0; i < numBoxes; i++)                // 循环所有窗口   
  27.     {  
  28.         if (!is_suppressed[indices[i]])           // 判断窗口是否被抑制   
  29.         {  
  30.             for (j = i + 1; j < numBoxes; j++)    // 循环当前窗口之后的窗口   
  31.             {  
  32.                 if (!is_suppressed[indices[j]])   // 判断窗口是否被抑制   
  33.                 {  
  34.                     int x1max = max(points[indices[i]].x, points[indices[j]].x);                     // 求两个窗口左上角x坐标最大值   
  35.                     int x2min = min(oppositePoints[indices[i]].x, oppositePoints[indices[j]].x);     // 求两个窗口右下角x坐标最小值   
  36.                     int y1max = max(points[indices[i]].y, points[indices[j]].y);                     // 求两个窗口左上角y坐标最大值   
  37.                     int y2min = min(oppositePoints[indices[i]].y, oppositePoints[indices[j]].y);     // 求两个窗口右下角y坐标最小值   
  38.                     int overlapWidth = x2min - x1max + 1;            // 计算两矩形重叠的宽度   
  39.                     int overlapHeight = y2min - y1max + 1;           // 计算两矩形重叠的高度   
  40.                     if (overlapWidth > 0 && overlapHeight > 0)  
  41.                     {  
  42.                         float overlapPart = (overlapWidth * overlapHeight) / box_area[indices[j]];    // 计算重叠的比率   
  43.                         if (overlapPart > overlapThreshold)          // 判断重叠比率是否超过重叠阈值   
  44.                         {  
  45.                             is_suppressed[indices[j]] = 1;           // 将窗口j标记为抑制   
  46.                         }  
  47.                     }  
  48.                 }  
  49.             }  
  50.         }  
  51.     }  
  52.   
  53.     *numBoxesOut = 0;    // 初始化输出窗口数目0   
  54.     for (i = 0; i < numBoxes; i++)  
  55.     {  
  56.         if (!is_suppressed[i]) (*numBoxesOut)++;    // 统计输出窗口数目   
  57.     }  
  58.   
  59.     *pointsOut = (CvPoint *)malloc((*numBoxesOut) * sizeof(CvPoint));           // 分配输出窗口左上角坐标空间   
  60.     *oppositePointsOut = (CvPoint *)malloc((*numBoxesOut) * sizeof(CvPoint));   // 分配输出窗口右下角坐标空间   
  61.     *scoreOut = (float *)malloc((*numBoxesOut) * sizeof(float));                // 分配输出窗口得分空间   
  62.     index = 0;  
  63.     for (i = 0; i < numBoxes; i++)                  // 遍历所有输入窗口   
  64.     {  
  65.         if (!is_suppressed[indices[i]])             // 将未发生抑制的窗口信息保存到输出信息中   
  66.         {  
  67.             (*pointsOut)[index].x = points[indices[i]].x;  
  68.             (*pointsOut)[index].y = points[indices[i]].y;  
  69.             (*oppositePointsOut)[index].x = oppositePoints[indices[i]].x;  
  70.             (*oppositePointsOut)[index].y = oppositePoints[indices[i]].y;  
  71.             (*scoreOut)[index] = score[indices[i]];  
  72.             index++;  
  73.         }  
  74.   
  75.     }  
  76.   
  77.     free(indices);          // 释放indices空间   
  78.     free(box_area);         // 释放box_area空间   
  79.     free(is_suppressed);    // 释放is_suppressed空间   
  80.   
  81.     return LATENT_SVM_OK;  
  82. }  


c++版本程序如下所示,根据opencv源码改编,vs2010实测运行完美。由于c和c++版本基本一个思路,因此将这两个的思路一起讲解。

         整体程序分为两部分,sort函数主要实现候选框的置信度从高到低的排序,是基于基本的选择排序实现。nonMaximumSuppression主要实现非极大值抑制算法。算法思路为,先根据候选框的points 和oppositePoints 求出每个候选框的面积box_area,并将标签is_suppressed全部置为0。通过一个二重for循环将所有的候选框进行比对,这里的循环是从置信度最高的窗口进行比对,每层外循环中置信度最高的保留,其余的只要大于规定阈值overlapThreshold就舍弃,不大于阈值的保留下来。最终输出NMS处理后的结果。

[cpp]  view plain  copy
  1. static void sort(int n, const vector<float> x, vector<int> indices)  
  2. {  
  3. // 排序函数,排序后进行交换的是indices中的数据  
  4. // n:排序总数// x:带排序数// indices:初始为0~n-1数目   
  5.       
  6.     int i, j;  
  7.     for (i = 0; i < n; i++)  
  8.         for (j = i + 1; j < n; j++)  
  9.         {  
  10.             if (x[indices[j]] > x[indices[i]])  
  11.             {  
  12.                 //float x_tmp = x[i];  
  13.                 int index_tmp = indices[i];  
  14.                 //x[i] = x[j];  
  15.                 indices[i] = indices[j];  
  16.                 //x[j] = x_tmp;  
  17.                 indices[j] = index_tmp;  
  18.             }  
  19.         }  
  20. }  


[cpp]  view plain  copy
  1. int nonMaximumSuppression(int numBoxes, const vector<CvPoint> points,const vector<CvPoint> oppositePoints,   
  2.     const vector<float> score,    float overlapThreshold,int& numBoxesOut, vector<CvPoint>& pointsOut,  
  3.     vector<CvPoint>& oppositePointsOut, vector<float> scoreOut)   
  4. {  
  5. // 实现检测出的矩形窗口的非极大值抑制nms  
  6. // numBoxes:窗口数目// points:窗口左上角坐标点// oppositePoints:窗口右下角坐标点// score:窗口得分  
  7. // overlapThreshold:重叠阈值控制// numBoxesOut:输出窗口数目// pointsOut:输出窗口左上角坐标点  
  8. // oppositePoints:输出窗口右下角坐标点// scoreOut:输出窗口得分  
  9.     int i, j, index;  
  10.     vector<float> box_area(numBoxes);             // 定义窗口面积变量并分配空间   
  11.     vector<int> indices(numBoxes);                    // 定义窗口索引并分配空间   
  12.     vector<int> is_suppressed(numBoxes);          // 定义是否抑制表标志并分配空间   
  13.     // 初始化indices、is_supperssed、box_area信息   
  14.     for (i = 0; i < numBoxes; i++)  
  15.     {  
  16.         indices[i] = i;  
  17.         is_suppressed[i] = 0;  
  18.         box_area[i] = (float)( (oppositePoints[i].x - points[i].x + 1) *(oppositePoints[i].y - points[i].y + 1));  
  19.     }  
  20.     // 对输入窗口按照分数比值进行排序,排序后的编号放在indices中   
  21.     sort(numBoxes, score, indices);  
  22.     for (i = 0; i < numBoxes; i++)                // 循环所有窗口   
  23.     {  
  24.         if (!is_suppressed[indices[i]])           // 判断窗口是否被抑制   
  25.         {  
  26.             for (j = i + 1; j < numBoxes; j++)    // 循环当前窗口之后的窗口   
  27.             {  
  28.                 if (!is_suppressed[indices[j]])   // 判断窗口是否被抑制   
  29.                 {  
  30.                     int x1max = max(points[indices[i]].x, points[indices[j]].x);                     // 求两个窗口左上角x坐标最大值   
  31.                     int x2min = min(oppositePoints[indices[i]].x, oppositePoints[indices[j]].x);     // 求两个窗口右下角x坐标最小值   
  32.                     int y1max = max(points[indices[i]].y, points[indices[j]].y);                     // 求两个窗口左上角y坐标最大值   
  33.                     int y2min = min(oppositePoints[indices[i]].y, oppositePoints[indices[j]].y);     // 求两个窗口右下角y坐标最小值   
  34.                     int overlapWidth = x2min - x1max + 1;     // 计算两矩形重叠的宽度   
  35.                     int overlapHeight = y2min - y1max + 1;     // 计算两矩形重叠的高度   
  36.                     if (overlapWidth > 0 && overlapHeight > 0)  
  37.                     {  
  38.                         float overlapPart = (overlapWidth * overlapHeight) / box_area[indices[j]];    // 计算重叠的比率   
  39.                         if (overlapPart > overlapThreshold)   // 判断重叠比率是否超过重叠阈值   
  40.                         {  
  41.                             is_suppressed[indices[j]] = 1;     // 将窗口j标记为抑制   
  42.                         }  
  43.                     }  
  44.                 }  
  45.             }  
  46.         }  
  47.     }  
  48.   
  49.     numBoxesOut = 0;    // 初始化输出窗口数目0   
  50.     for (i = 0; i < numBoxes; i++)  
  51.     {  
  52.         if (!is_suppressed[i]) numBoxesOut++;    // 统计输出窗口数目   
  53.     }  
  54.     index = 0;  
  55.     for (i = 0; i < numBoxes; i++)            // 遍历所有输入窗口   
  56.     {  
  57.         if (!is_suppressed[indices[i]])       // 将未发生抑制的窗口信息保存到输出信息中   
  58.         {  
  59.             pointsOut.push_back(Point(points[indices[i]].x,points[indices[i]].y));  
  60.             oppositePointsOut.push_back(Point(oppositePoints[indices[i]].x,oppositePoints[indices[i]].y));  
  61.             scoreOut.push_back(score[indices[i]]);  
  62.             index++;  
  63.         }  
  64.   
  65.     }  
  66.   
  67.     return true;  
  68. }  

2、更好的优化代码如下:


bool CompareBBox(const seeta::FaceInfo & a, const seeta::FaceInfo & b) {
  return a.score > b.score;
}

void NonMaximumSuppression(std::vector<seeta::FaceInfo>* bboxes,
  std::vector<seeta::FaceInfo>* bboxes_nms, float iou_thresh) {
  bboxes_nms->clear();
  std::sort(bboxes->begin(), bboxes->end(), seeta::fd::CompareBBox);

  int32_t select_idx = 0;
  int32_t num_bbox = static_cast<int32_t>(bboxes->size());
  std::vector<int32_t> mask_merged(num_bbox, 0);
  bool all_merged = false;

  while (!all_merged) {
    while (select_idx < num_bbox && mask_merged[select_idx] == 1)
      select_idx++;
    if (select_idx == num_bbox) {
      all_merged = true;
      continue;
    }

    bboxes_nms->push_back((*bboxes)[select_idx]);
    mask_merged[select_idx] = 1;

    seeta::Rect select_bbox = (*bboxes)[select_idx].bbox;
    float area1 = static_cast<float>(select_bbox.width * select_bbox.height);
    float x1 = static_cast<float>(select_bbox.x);
    float y1 = static_cast<float>(select_bbox.y);
    float x2 = static_cast<float>(select_bbox.x + select_bbox.width - 1);
    float y2 = static_cast<float>(select_bbox.y + select_bbox.height - 1);

    select_idx++;
    for (int32_t i = select_idx; i < num_bbox; i++) {
      if (mask_merged[i] == 1)
        continue;

      seeta::Rect & bbox_i = (*bboxes)[i].bbox;
      float x = std::max<float>(x1, static_cast<float>(bbox_i.x));
      float y = std::max<float>(y1, static_cast<float>(bbox_i.y));
      float w = std::min<float>(x2, static_cast<float>(bbox_i.x + bbox_i.width - 1)) - x + 1;
      float h = std::min<float>(y2, static_cast<float>(bbox_i.y + bbox_i.height - 1)) - y + 1;
      if (w <= 0 || h <= 0)
        continue;

      float area2 = static_cast<float>(bbox_i.width * bbox_i.height);
      float area_intersect = w * h;
      float area_union = area1 + area2 - area_intersect;
      if (static_cast<float>(area_intersect) / area_union > iou_thresh) {
        mask_merged[i] = 1;
        bboxes_nms->back().score += (*bboxes)[i].score;
      }
    }
  }
}



实验结果

      c++和Matlab的测试结果如下所示,其中,红色框为经过NMS处理后的结果,黄色框为原始的输入框。从图中可以看出,经过NMS处理后的候选框中,在重叠部分大于规定阈值的都被舍弃,只保留其中置信度最高的一个,而对于没有重叠的框,不管其置信度多少,都直接保留下来。

       注意,在matlab,opencv里面图像左上角为坐标原点,而本文在matlab中是单纯画图,此时图像左下角为坐标原点,所以同样的坐标,两幅图效果有所区别。

本文所有程序github下载链接https://github.com/watersink/nonMaximumSuppression

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值