目录
References
What is NMS?
NMS - Non-maximum suppression, 非极大值抑制
在图像识别目标检测任务中,通常用框框(Bounding Box)表示识别出来的对象,对于同一个对象可能有很多个有着不同重叠区域的Bounding Box。每个Bounding Box对应的属于该特征的概率也不尽相同(confidence,置信度)。
所以NMS是用来抑制 非 最大置信度 的识别结果,只留下来置信度高的(是该特征的概率最大)结果。
Relevant Concepts
- Bounding Box
- IoU - Intersection over Union, 交并比 : 两个Bounding Box的交集区域面积,除以并集区域面积
Example Logic
- input-Bounding Box List和IoU的阈值, output-Bounding List with NMS processing
- 按照confidence对Bounding Box从高到低排序
- 计算Bounding Box两两的IoU, 如果IoU大于阈值,从List删除当前两个BBox中confidence的低的那个
- 全部处理完,输出新的BBox List
Code
//NMS.H
#pragma once
void nms_test();
typedef struct BBox
{
int x; //(x,y) left-up position, for x, left is increasing
int y; // for y, down is increasing
int w; //width
int h; //hight
float confidence;
}BBox;
//NMS_test.cpp
#include "NMS.h"
#include <vector>
#include <algorithm>
#include <iostream>
float IOU(BBox box1, BBox box2)
{
float iou = 0;
int x_over_l = std::max(box1.x, box2.x); //left x position of over area
int x_over_r = std::min(box1.x + box1.w, box2.x + box2.w); //right position of over area
int y_over_u = std::max(box1.y, box2.y);
int y_over_d = std::min(box1.y + box1.h, box2.y + box2.h);
int over_w = x_over_r - x_over_l;
int over_h = y_over_d - y_over_u;
if (over_w <= 0 || over_h <= 0)
{
iou = 0;
}
else {
iou = (float)over_w * over_h / (box1.w * box1.h + box2.w * box2.h - over_w * over_h);
}
return iou;
}
std::vector<BBox> NMS(std::vector<BBox> &boxes, float threshold)
{
std::vector<BBox>resluts;
sort(boxes.begin(), boxes.end(), [](BBox a, BBox b) {return a.confidence > b.confidence;});
while (boxes.size() > 0)
{
resluts.push_back(boxes[0]);
int index = 1;
while (index < boxes.size()) {
float iou_value = IOU(boxes[0], boxes[index]);
std::cout << "iou_value=" << iou_value << '\n';
if (iou_value > threshold) {
boxes.erase(boxes.begin() + index);
}
else {
index++;
}
}
boxes.erase(boxes.begin());
}
return resluts;
}
void nms_test()
{
std::cout << "========NMS test begin======" << '\n';
std::vector<BBox> input;
BBox box1 = { 10,10,10,10,0.5 };
BBox box2 = { 0,0,20,20,0.6 };
input.push_back(box1);
input.push_back(box2);
std::vector<BBox> res;
res = NMS(input, 0.15);
for (int i = 0; i < res.size(); i++) {
printf("%d %d %d %d %f", res[i].x, res[i].y, res[i].w, res[i].h, res[i].confidence);
std::cout << '\n';
}
std::cout << "========NMS test end======" << '\n';
}
//main.cpp
#include "NMS.h"
int main()
{
nms_test();
}