// NMS (non-maximum suppression) — C++ implementation
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;
/// Greedy non-maximum suppression (NMS) over axis-aligned 2-D boxes.
///
/// Each box is a row of at least five floats: {x1, y1, x2, y2, score, ...},
/// where (x1, y1) and (x2, y2) are opposite corners with x2 >= x1, y2 >= y1.
/// Values after index 4 are carried through to the output unchanged.
///
/// Usage: construct with the thresholds and raw boxes, call nmsProcess(),
/// then read the surviving boxes with getboxShapesIouFiltered().
class NMS {
public:
    /// @param iouTheta       a box is suppressed when its IoU with an already
    ///                       kept (higher-scored) box is >= this value.
    /// @param scoreTheta     boxes whose score is < this value are discarded.
    /// @param cubesRawShapes input boxes, each {x1, y1, x2, y2, score, ...}.
    NMS(float iouTheta, float scoreTheta, std::vector<std::vector<float>> cubesRawShapes)
        : iouTheta(iouTheta), scoreTheta(scoreTheta), cubesRawShapes(std::move(cubesRawShapes)) {}

    /// @return the boxes kept by the last nmsProcess() call, highest score
    ///         first (ties keep input order); empty before the first call.
    std::vector<std::vector<float>> getboxShapesIouFiltered() const { return boxShapesIouFiltered; }

    /// Intersection-over-union of two boxes laid out as {x1, y1, x2, y2, ...}.
    /// Both rows must hold at least four values.
    /// @return IoU in [0, 1]; 0 when the boxes do not overlap (edge contact
    ///         counts as no overlap) or when the union area is not positive.
    float computeIou(const std::vector<float>& cubeShape1, const std::vector<float>& cubeShape2) const
    {
        // Overlap extent along each axis; a non-positive extent means disjoint.
        const float w = std::min(cubeShape1[2], cubeShape2[2]) - std::max(cubeShape1[0], cubeShape2[0]);
        const float h = std::min(cubeShape1[3], cubeShape2[3]) - std::max(cubeShape1[1], cubeShape2[1]);
        if (w <= 0.0f || h <= 0.0f) return 0.0f;

        const float intersection = w * h;
        const float area1 = (cubeShape1[2] - cubeShape1[0]) * (cubeShape1[3] - cubeShape1[1]);
        const float area2 = (cubeShape2[2] - cubeShape2[0]) * (cubeShape2[3] - cubeShape2[1]);
        const float unionArea = area1 + area2 - intersection;
        // Malformed (inverted) boxes can make the union zero or negative;
        // avoid dividing by it.
        if (unionArea <= 0.0f) return 0.0f;
        return intersection / unionArea;
    }

    /// Runs NMS on the boxes given to the constructor. Results from any
    /// previous call are discarded, so calling this repeatedly is safe.
    /// Rows with fewer than five values are skipped.
    void nmsProcess()
    {
        boxShapesScoreFiltered.clear();
        boxShapesIouFiltered.clear();

        // 1) Keep well-formed rows whose score passes the threshold
        //    (NaN scores fail the >= test and are dropped as well).
        for (const auto& box : cubesRawShapes)
        {
            if (box.size() > kScoreIdx && box[kScoreIdx] >= scoreTheta)
            {
                boxShapesScoreFiltered.push_back(box);
            }
        }

        // 2) Highest score first. Stable, so among equal scores the box that
        //    appeared earlier in the input is considered first.
        std::stable_sort(boxShapesScoreFiltered.begin(), boxShapesScoreFiltered.end(),
                         [](const std::vector<float>& a, const std::vector<float>& b) {
                             return a[kScoreIdx] > b[kScoreIdx];
                         });

        // 3) Greedy sweep: a box not yet suppressed is the best remaining one,
        //    so keep it and suppress every later box that overlaps it too much.
        const std::size_t count = boxShapesScoreFiltered.size();
        std::vector<char> suppressed(count, 0);  // char, not bool: avoids the vector<bool> proxy
        for (std::size_t i = 0; i < count; ++i)
        {
            if (suppressed[i]) continue;
            const std::vector<float>& kept = boxShapesScoreFiltered[i];
            boxShapesIouFiltered.push_back(kept);
            for (std::size_t j = i + 1; j < count; ++j)
            {
                if (!suppressed[j] && computeIou(kept, boxShapesScoreFiltered[j]) >= iouTheta)
                {
                    suppressed[j] = 1;
                }
            }
        }
    }

private:
    static constexpr std::size_t kScoreIdx = 4;  // position of the score in each box row

    float iouTheta = 0.f;
    float scoreTheta = 0.f;
    std::vector<std::vector<float>> cubesRawShapes{};          // input as given to the constructor
    std::vector<std::vector<float>> boxShapesScoreFiltered{};  // score-passing boxes, sorted by score desc
    std::vector<std::vector<float>> boxShapesIouFiltered{};    // NMS survivors
};
int main()
{
vector<vector<float>> cubesRawShapes = { {30, 10, 200, 200, 0.95},
{25, 15, 180, 220, 0.98},
{35, 40, 190, 170, 0.96},
{60, 60, 90, 90, 0.3},
{20, 30, 40, 50, 0.1} };
const float iouTheta = 0.5;
const float scoreTheta = 0.5;
NMS nms(iouTheta, scoreTheta, cubesRawShapes);
nms.nmsProcess();
vector<vector<float>> cubesFilterShapes = nms.getboxShapesIouFiltered();
for (int i = 0; i < cubesFilterShapes.size(); i++)
{
for (int j = 0; j < cubesFilterShapes[i].size(); j++)
{
cout << cubesFilterShapes[i][j] << " ";
}
cout << endl;
}
return 0;
}