非极大值抑制(Non-Maximum-Suppression)

转自:https://blog.csdn.net/u014365862/article/details/53376516

非极大值抑制(Non-Maximum-Suppression)


注意看哦,有两个版本的。


理论基础

         说实话,讲理论基础实在不是我的强项,但是还是得硬着头皮来讲,希望我的讲解不至于晦涩难懂。

         非极大值抑制,简称为NMS算法。是一种获取局部最大值的有效方法。在3领域中,假设一个行向量的长度为w,从左向右,由第一个到第w个和其3领域中的数值进行比对。

如果某个i大于i+1并且小于i-1,则其为一个绝不最大值,同时也就意味着i+1不是一个局部最大值,所以将i移动2个步长,从i+2开始继续向后进行比较判断。如果某个i不满足上述条件,则将i+1,继续对i+1进行比对。当比对到最后一个w时,直接将w设置为局部最大值。算法流程如下图所示。

应用范围

         非极大值抑制NMS在目标检测,定位等领域是一种被广泛使用的方法。对于目标具体位置定位过程,不管是使用sw(sliding Window)还是ss(selective search)方法,都会产生好多的候选区域。实际看到的情形就是好多区域的交叉重叠,难以满足实际的应用。如下图所示。

针对该问题有3种传统的解决思路。

         第一种,选取好多矩形框的交集,即公共区域作为最后的目标区域。

         第二种,选取好多矩形框的并集,即所有矩形框的最小外截矩作为目标区域。当然这里也不是只要相交就直接取并集,需要相交的框满足交集占最小框的面积达到一定比例(也就是阈值)才合并。

         第三种,也就是本文的NMS,简单的说,对于有相交的就选取其中置信度最高的一个作为最后结果,对于没相交的就直接保留下来,作为最后结果。

         总体来说,3种处理思路都各有千秋,不能一概评论哪种好坏。各种顶会论文也会选择不同的处理方法。



[cpp]  view plain  copy
  1. <span style="font-size:14px;">#include <iostream>  
  2. #include <opencv2/core/core.hpp>  
  3. #include <opencv2/highgui/highgui.hpp>  
  4. #include <opencv2/opencv.hpp>  
  5. // 新版本写在下面文件中:  
  6. #include <opencv2/nonfree/features2d.hpp>  
  7. //#include "opencv2/features2d/features2d.hpp"  
  8. #include<opencv2/legacy/legacy.hpp>  
  9.   
  10. using namespace std;  
  11. using namespace cv;  
  12.   
  13.   
  14. void nms(  
  15.          const std::vector<cv::Rect>& srcRects,  
  16.          std::vector<cv::Rect>& resRects,  
  17.          float thresh  
  18.          )  
  19. {  
  20.     resRects.clear();  
  21.       
  22.     const size_t size = srcRects.size();  
  23.     if (!size)  
  24.     {  
  25.         return;  
  26.     }  
  27.       
  28.     // Sort the bounding boxes by the bottom - right y - coordinate of the bounding box  
  29.     std::multimap<intsize_t> idxs;  
  30.     for (size_t i = 0; i < size; ++i)  
  31.     {  
  32.         idxs.insert(std::pair<intsize_t>(srcRects[i].br().y, i));  
  33.     }  
  34.       
  35.     // keep looping while some indexes still remain in the indexes list  
  36.     while (idxs.size() > 0)  
  37.     {  
  38.         // grab the last rectangle  
  39.         auto lastElem = --std::end(idxs);  
  40.         const cv::Rect& rect1 = srcRects[lastElem->second];  
  41.           
  42.         resRects.push_back(rect1);  
  43.           
  44.         idxs.erase(lastElem);  
  45.           
  46.         for (auto pos = std::begin(idxs); pos != std::end(idxs); )  
  47.         {  
  48.             // grab the current rectangle  
  49.             const cv::Rect& rect2 = srcRects[pos->second];  
  50.               
  51.             float intArea = (rect1 & rect2).area();  
  52.             float unionArea = rect1.area() + rect2.area() - intArea;  
  53.             float overlap = intArea / unionArea;  
  54.               
  55.             // if there is sufficient overlap, suppress the current bounding box  
  56.             if (overlap > thresh)  
  57.             {  
  58.                 pos = idxs.erase(pos);  
  59.             }  
  60.             else  
  61.             {  
  62.                 ++pos;  
  63.             }  
  64.         }  
  65.     }  
  66. }  
  67.   
  68.   
  69. /** 
  70.  ******************************************************************************* 
  71.  * 
  72.  *   main 
  73.  * 
  74.  ******************************************************************************* 
  75.  */  
  76. int main(int argc, char* argv[])  
  77. {  
  78.     std::vector<cv::Rect> srcRects;  
  79.       
  80.     /* 
  81.      // Test 1 
  82.      srcRects.push_back(cv::Rect(cv::Point(114, 60), cv::Point(178, 124))); 
  83.      srcRects.push_back(cv::Rect(cv::Point(120, 60), cv::Point(184, 124))); 
  84.      srcRects.push_back(cv::Rect(cv::Point(114, 66), cv::Point(178, 130)));*/  
  85.       
  86.     /* 
  87.      // Test 2 
  88.      srcRects.push_back(cv::Rect(cv::Point(12, 84), cv::Point(140, 212))); 
  89.      srcRects.push_back(cv::Rect(cv::Point(24, 84), cv::Point(152, 212))); 
  90.      srcRects.push_back(cv::Rect(cv::Point(12, 96), cv::Point(140, 224))); 
  91.      srcRects.push_back(cv::Rect(cv::Point(36, 84), cv::Point(164, 212))); 
  92.      srcRects.push_back(cv::Rect(cv::Point(24, 96), cv::Point(152, 224))); 
  93.      srcRects.push_back(cv::Rect(cv::Point(24, 108), cv::Point(152, 236)));*/  
  94.       
  95.     // Test 3  
  96.     srcRects.push_back(cv::Rect(cv::Point(12, 30), cv::Point(76, 94)));  
  97.     srcRects.push_back(cv::Rect(cv::Point(12, 36), cv::Point(76, 100)));  
  98.     srcRects.push_back(cv::Rect(cv::Point(72, 36), cv::Point(200, 164)));  
  99.     srcRects.push_back(cv::Rect(cv::Point(84, 48), cv::Point(212, 176)));  
  100.       
  101.     cv::Size size(0, 0);  
  102.     for (const auto& r : srcRects)  
  103.     {  
  104.         size.width = std::max(size.width, r.x + r.width);  
  105.         size.height = std::max(size.height, r.y + r.height);  
  106.     }  
  107.       
  108.     cv::Mat img = cv::Mat(2 * size.height, 2 * size.width, CV_8UC3, cv::Scalar(0, 0, 0));  
  109.       
  110.     cv::Mat imgCopy = img.clone();  
  111.       
  112.       
  113.       
  114.     for (auto r : srcRects)  
  115.     {  
  116.         cv::rectangle(img, r, cv::Scalar(0, 0, 255), 2);  
  117.     }  
  118.       
  119.     cv::namedWindow("before", cv::WINDOW_NORMAL);  
  120.     cv::imshow("before", img);  
  121.     cv::waitKey(1);  
  122.       
  123.     std::vector<cv::Rect> resRects;  
  124.     nms(srcRects, resRects, 0.3f);  
  125.       
  126.     for (auto r : resRects)  
  127.     {  
  128.         cv::rectangle(imgCopy, r, cv::Scalar(0, 255, 0), 2);  
  129.     }  
  130.       
  131.     cv::namedWindow("after", cv::WINDOW_NORMAL);  
  132.     cv::imshow("after", imgCopy);  
  133.       
  134.     cv::waitKey(0);  
  135.       
  136.     return 0;  
  137. }</span><span style="font-size:11px;">  
  138. </span>  

实验结果:


[cpp]  view plain  copy
  1. <span style="font-size:14px;">#include <iostream>  
  2. #include <opencv2/core/core.hpp>  
  3. #include <opencv2/highgui/highgui.hpp>  
  4. #include <opencv2/opencv.hpp>  
  5. // 新版本写在下面文件中:  
  6. #include <opencv2/nonfree/features2d.hpp>  
  7. //#include "opencv2/features2d/features2d.hpp"  
  8. #include<opencv2/legacy/legacy.hpp>  
  9.   
  10. using namespace std;  
  11. using namespace cv;  
  12.   
  13.   
  14.   
  15. static void sort(int n, const vector<float> x, vector<int> indices)  
  16. {  
  17.     // 排序函数,排序后进行交换的是indices中的数据  
  18.     // n:排序总数// x:带排序数// indices:初始为0~n-1数目  
  19.       
  20.     int i, j;  
  21.     for (i = 0; i < n; i++)  
  22.         for (j = i + 1; j < n; j++)  
  23.         {  
  24.             if (x[indices[j]] > x[indices[i]])  
  25.             {  
  26.                 //float x_tmp = x[i];  
  27.                 int index_tmp = indices[i];  
  28.                 //x[i] = x[j];  
  29.                 indices[i] = indices[j];  
  30.                 //x[j] = x_tmp;  
  31.                 indices[j] = index_tmp;  
  32.             }  
  33.         }  
  34. }  
  35.   
  36.   
  37.   
  38. int nonMaximumSuppression(int numBoxes, const vector<CvPoint> points,const vector<CvPoint> oppositePoints,  
  39.                           const vector<float> score,  float overlapThreshold,int& numBoxesOut, vector<CvPoint>& pointsOut,  
  40.                           vector<CvPoint>& oppositePointsOut, vector<float> scoreOut)  
  41. {  
  42.     // 实现检测出的矩形窗口的非极大值抑制nms  
  43.     // numBoxes:窗口数目// points:窗口左上角坐标点// oppositePoints:窗口右下角坐标点// score:窗口得分  
  44.     // overlapThreshold:重叠阈值控制// numBoxesOut:输出窗口数目// pointsOut:输出窗口左上角坐标点  
  45.     // oppositePoints:输出窗口右下角坐标点// scoreOut:输出窗口得分  
  46.     int i, j, index;  
  47.     vector<float> box_area(numBoxes);             // 定义窗口面积变量并分配空间  
  48.     vector<int> indices(numBoxes);                    // 定义窗口索引并分配空间  
  49.     vector<int> is_suppressed(numBoxes);          // 定义是否抑制表标志并分配空间  
  50.     // 初始化indices、is_supperssed、box_area信息  
  51.     for (i = 0; i < numBoxes; i++)  
  52.     {  
  53.         indices[i] = i;  
  54.         is_suppressed[i] = 0;  
  55.         box_area[i] = (float)( (oppositePoints[i].x - points[i].x + 1) *(oppositePoints[i].y - points[i].y + 1));  
  56.     }  
  57.     // 对输入窗口按照分数比值进行排序,排序后的编号放在indices中  
  58.     sort(numBoxes, score, indices);  
  59.     for (i = 0; i < numBoxes; i++)                // 循环所有窗口  
  60.     {  
  61.         if (!is_suppressed[indices[i]])           // 判断窗口是否被抑制  
  62.         {  
  63.             for (j = i + 1; j < numBoxes; j++)    // 循环当前窗口之后的窗口  
  64.             {  
  65.                 if (!is_suppressed[indices[j]])   // 判断窗口是否被抑制  
  66.                 {  
  67.                     int x1max = max(points[indices[i]].x, points[indices[j]].x);                     // 求两个窗口左上角x坐标最大值  
  68.                     int x2min = min(oppositePoints[indices[i]].x, oppositePoints[indices[j]].x);     // 求两个窗口右下角x坐标最小值  
  69.                     int y1max = max(points[indices[i]].y, points[indices[j]].y);                     // 求两个窗口左上角y坐标最大值  
  70.                     int y2min = min(oppositePoints[indices[i]].y, oppositePoints[indices[j]].y);     // 求两个窗口右下角y坐标最小值  
  71.                     int overlapWidth = x2min - x1max + 1;            // 计算两矩形重叠的宽度  
  72.                     int overlapHeight = y2min - y1max + 1;           // 计算两矩形重叠的高度  
  73.                     if (overlapWidth > 0 && overlapHeight > 0)  
  74.                     {  
  75.                         float overlapPart = (overlapWidth * overlapHeight) / box_area[indices[j]];    // 计算重叠的比率  
  76.                         if (overlapPart > overlapThreshold)          // 判断重叠比率是否超过重叠阈值  
  77.                         {  
  78.                             is_suppressed[indices[j]] = 1;           // 将窗口j标记为抑制  
  79.                         }  
  80.                     }  
  81.                 }  
  82.             }  
  83.         }  
  84.     }  
  85.       
  86.     numBoxesOut = 0;    // 初始化输出窗口数目0  
  87.     for (i = 0; i < numBoxes; i++)  
  88.     {  
  89.         if (!is_suppressed[i]) numBoxesOut++;    // 统计输出窗口数目  
  90.     }  
  91.     index = 0;  
  92.     for (i = 0; i < numBoxes; i++)                  // 遍历所有输入窗口  
  93.     {  
  94.         if (!is_suppressed[indices[i]])             // 将未发生抑制的窗口信息保存到输出信息中  
  95.         {  
  96.             pointsOut.push_back(Point(points[indices[i]].x,points[indices[i]].y));  
  97.             oppositePointsOut.push_back(Point(oppositePoints[indices[i]].x,oppositePoints[indices[i]].y));  
  98.             scoreOut.push_back(score[indices[i]]);  
  99.             index++;  
  100.         }  
  101.           
  102.     }  
  103.       
  104.     return true;  
  105. }  
  106.   
  107. int main()  
  108. {  
  109.     Mat image=Mat::zeros(600,600,CV_8UC3);  
  110.     int numBoxes=4;  
  111.     vector<CvPoint> points(numBoxes);  
  112.     vector<CvPoint> oppositePoints(numBoxes);  
  113.     vector<float> score(numBoxes);  
  114.       
  115.     points[0]=Point(200,200);oppositePoints[0]=Point(400,400);score[0]=0.99;  
  116.     points[1]=Point(220,220);oppositePoints[1]=Point(420,420);score[1]=0.9;  
  117.     points[2]=Point(100,100);oppositePoints[2]=Point(150,150);score[2]=0.82;  
  118.     points[3]=Point(200,240);oppositePoints[3]=Point(400,440);score[3]=0.5;  
  119.       
  120.       
  121.     float overlapThreshold=0.8;  
  122.     int numBoxesOut;  
  123.     vector<CvPoint> pointsOut;  
  124.     vector<CvPoint> oppositePointsOut;  
  125.     vector<float> scoreOut;  
  126.       
  127.     nonMaximumSuppression( numBoxes,points,oppositePoints,score,overlapThreshold,numBoxesOut,pointsOut,oppositePointsOut,scoreOut);  
  128.     for (int i=0;i<numBoxes;i++)  
  129.     {  
  130.         rectangle(image,points[i],oppositePoints[i],Scalar(0,255,255),6);  
  131.         char text[20];  
  132.         sprintf(text,"%f",score[i]);  
  133.         putText(image,text,points[i],CV_FONT_HERSHEY_COMPLEX, 1,Scalar(0,255,255));  
  134.     }  
  135.     for (int i=0;i<numBoxesOut;i++)  
  136.     {  
  137.         rectangle(image,pointsOut[i],oppositePointsOut[i],Scalar(0,0,255),2);  
  138.     }  
  139.       
  140.     imshow("result",image);  
  141.       
  142.     waitKey();  
  143.     return 0;  
  144. }  
  145. </span>  

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值