#include <iostream>
#include <vector>
# include <algorithm>
using namespace std;
/// Greedy non-maximum suppression (NMS) over axis-aligned integer boxes.
class NMS{
public:
    /// Axis-aligned box with a confidence score.
    /// NOTE the field order: x1, x2, y1, y2 (not x1, y1, x2, y2), so aggregate
    /// initialisation reads Bbox{x1, x2, y1, y2, score}.
    /// A box is expected to satisfy x1 < x2 and y1 < y2; a box that does not
    /// (zero-width or inverted) never overlaps anything according to IOU().
    class Bbox {
    public:
        int x1,x2,y1,y2;
        float score;
    };

    /// Intersection-over-union of two boxes.
    /// @return |b1 ∩ b2| / |b1 ∪ b2|, in [0, 1]. Returns 0 when the boxes do not
    ///         overlap; boxes that only touch along an edge count as not overlapping.
    float IOU(const Bbox& b1, const Bbox& b2) const {
        const int ix1 = std::max(b1.x1, b2.x1);
        const int ix2 = std::min(b1.x2, b2.x2);
        const int iy1 = std::max(b1.y1, b2.y1);
        const int iy2 = std::min(b1.y2, b2.y2);
        if (ix1 >= ix2 || iy1 >= iy2) {
            return 0.0f;  // empty intersection
        }
        const double inter = (static_cast<double>(ix2) - ix1) *
                             (static_cast<double>(iy2) - iy1);
        // Union = sum of areas minus the doubly-counted intersection.
        // A non-empty intersection implies both boxes have positive area,
        // so the union is strictly positive and the division is safe.
        const double uni = area(b1) + area(b2) - inter;
        return static_cast<float>(inter / uni);
    }

    /// Comparator ordering boxes by descending score.
    static bool cmp(const Bbox& b1, const Bbox& b2) {
        return b1.score > b2.score;
    }

    /// Greedy NMS. Boxes are visited from highest to lowest score. A box is kept
    /// unless its IoU with some already-kept box exceeds `threshold`.
    /// @param box_set   candidate boxes. Sorted in place by descending score;
    ///                  ties keep their original relative order.
    /// @param threshold a box whose IoU with a kept box is > threshold is dropped.
    /// @return the kept boxes in descending score order (empty if box_set is empty).
    std::vector<Bbox> nms(std::vector<Bbox>& box_set, float threshold) const {
        std::vector<Bbox> res;
        if (box_set.empty()) {
            return res;
        }
        std::stable_sort(box_set.begin(), box_set.end(), cmp);
        for (const Bbox& cand : box_set) {
            // Only boxes that survived suppression may suppress later ones.
            const bool suppressed = !std::none_of(
                res.begin(), res.end(),
                [&](const Bbox& kept) { return IOU(kept, cand) > threshold; });
            if (!suppressed) {
                res.push_back(cand);
            }
        }
        return res;
    }

private:
    /// Area of a box, computed in double to avoid int overflow.
    static double area(const Bbox& b) {
        return (static_cast<double>(b.x2) - b.x1) *
               (static_cast<double>(b.y2) - b.y1);
    }
};
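// c++手写NMS算法 (Hand-written NMS algorithm in C++.)
// 最新推荐文章于 2024-05-17 13:26:14 发布 (Source page note: latest recommended article published 2024-05-17 13:26:14.)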