总结一下NMS算法C++代码以及Python代码的实现(已测试过)
NMS算法是目标检测中取出冗余检测框的常用算法,它的基本步骤为:
(1)选择某一类物体的所有的检测框和置信度,将其放到一个容器中
(2)对检测框的置信度进行降序排序
(3)选择容器中,检测框的置信度最大的bbox,将其保存下来,然后与容器中剩余的元素依次进行IOU计算
(4)如果IOU计算的结果大于置信度阈值的话,将该检测框及其置信度从容器中剔除出去
(5)重复3-4步骤,直至容器为空,才停止算法。
1 C++代码实现NMS
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
// 思路:进行nms
// (1) 将所有的bbox根据置信度进行降序排序
// (2) 挑选conf最大的元素,如果conf小于conf阈值的话,我们就退出
// (3) 否则,将其索引保存到数组中,与其他剩余的bbox进行iou计算
// (4) 如果iou大于阈值,我们就将其从数组中剔除出去
// (5) 重复2-4,直到数组为空,我们才退出
struct Rect {
float x1;
float y1;
float x2;
float y2;
Rect() :x1(0), y1(0), x2(0), y2(0) {}
Rect(float x1, float y1, float x2, float y2) :x1(x1), y1(y1), x2(x2), y2(y2) {}
};
struct BBox {
int idx;
float conf;
Rect rect;
BBox() :idx(0), conf(0), rect(0, 0, 0, 0) {}
BBox(int idx, float conf, Rect rect) : idx(idx), conf(conf), rect(rect) {}
};
// 进行iou计算
float iou(Rect& r1, Rect& r2) {
float area1 = (r1.x2 - r1.x1) * (r1.y2 - r1.y1), area2 = (r2.x2 - r2.x1) * (r2.y2 - r2.y1);
float xx1 = max(r1.x1, r2.x1), yy1 = max(r1.y1, r2.y1), xx2 = min(r1.x2, r2.x2), yy2 = min(r1.y2, r2.y2);
float w = max(0.0f, xx2 - xx1), h = max(0.0f, yy2 - yy1);
float inter_area = h * w, union_area = area1 + area2 - inter_area;
return inter_area / union_area;
}
// 进行nms计算
vector<Rect> nms(vector<Rect>& rects, vector<float>& confs, float conf_thresh = 0.5, float iou_thresh = 0.5) {
vector<BBox> bboxes;
BBox bbox;
vector<int> keep_idx;
vector<Rect> ans;
for (int i = 0; i < (int)rects.size(); ++i) {
bboxes.push_back(BBox(i, confs[i], rects[i]));
}
// 对bbox的conf进行降序排序
sort(bboxes.begin(), bboxes.end(), [&](const BBox& a, const BBox& b) {
return a.conf > b.conf;
});
while (!bboxes.empty()) {
bbox = bboxes[0];
if (bbox.conf < conf_thresh) {
break;
}
keep_idx.emplace_back(bbox.idx);
bboxes.erase(bboxes.begin());
// 让conf最高的bbox与其他剩余的bbox进行iou计算
int size = bboxes.size();
for (int i = 0; i < size; ++i) {
float iou_ans = iou(bbox.rect, bboxes[i].rect);
if (iou_ans > iou_thresh) {
bboxes.erase(bboxes.begin() + i);
size = bboxes.size();
i = i - 1;
}
}
}
// keep_idx的输出结果是[2,1,0]
for (const int number : keep_idx) {
ans.push_back(rects[number]);
}
return ans;
}
int main() {
vector<Rect> rect = { Rect(100.0f, 100.0f, 210.0f, 210.0f),Rect(250.0f, 250.0f, 420.0f, 420.0f), Rect(220.0f, 220.0f, 320.0f, 330.0f),
Rect(100.0f, 100.0f, 210.0f, 210.0f),Rect(230.0f, 240.0f, 325.0f, 330.0f) ,Rect(220.0f, 230.0f, 315.0f, 340.0f) };
vector<float> conf = { 0.72f ,0.8f ,0.92f ,0.72f ,0.81f ,0.9f };
nms(rect, conf);
return 0;
}
c++代码实现思路比较清晰,但是在写代码的时候有一点比较重要,就是vector的erase函数。当我们在for循环中erase某个元素的时候,我们需要重新计算该容器的size,如果不更新该容器的size话,一旦我们erase某个元素,但是还按照为删除之前的size进行遍历的话,就会造成数组越界;另外一点,当iou大于iou阈值的时候,我们将i变化成了i-1,这是因为当第i个元素被删除之后,第i+1个元素就变成了第i个元素,而for循环处理之后又会将i递增,所以我们需要将i变成i-1。
2. Python实现NMS
import numpy as np
bboxes = np.array([[100, 100, 210, 210, 0.72],
[250, 250, 420, 420, 0.8],
[220, 220, 320, 330, 0.92],
[100, 100, 210, 210, 0.72],
[230, 240, 325, 330, 0.81],
[220, 230, 315, 340, 0.9]])
# 进行非极大值抑制
def nms(iou_thresh=0.5, conf_threash = 0.5):
# 基本思路:
# (1) 将置信度进行降序排序,然后选择置信度最大的bbox,将其保存下来
# (2) 将置信度最大的bbox和其他剩余的bbox进行交并比计算,将交并比大于阈值的bbox从这个集合中剔除出去
# (3) 如果这个集合不为空的话,我们就重复上面的计算
# 为了提高效率,我们保留bbox不动,最终保留的也都是bbox在原集合中的索引
x1, y1, x2, y2, confidence = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3], bboxes[:, 4]
# 计算面积
area = (x2 - x1) * (y2 - y1)
indices = confidence.argsort()[::-1]
keep = []
while indices.size > 0:
idx_self, idx_other = indices[0], indices[1:]
# 如果置信度小于阈值的话,那么后面的bbox就都不符合要求了,直接退出就行了
if confidence[idx_self] < conf_threash:
break
keep.append(idx_self)
# 计算交集
xx1, yy1 = np.maximum(x1[idx_self], x1[idx_other]), np.maximum(y1[idx_self], y1[idx_other])
xx2, yy2 = np.minimum(x2[idx_self], x2[idx_other]), np.minimum(y2[idx_self], y2[idx_other])
w, h = np.maximum(0, xx2 - xx1), np.maximum(0, yy2 - yy1)
intersection = w * h
# 计算并集(两个面积和-交集)
union = area[idx_self] + area[idx_other] - intersection
iou = intersection / union
# 只保留iou小于等于阈值的元素
print(np.where(iou <= iou_thresh))
keep_idx = np.where(iou <= iou_thresh)[0]
indices = indices[keep_idx + 1]
return np.array(keep)
# 进行非极大值抑制
if __name__ == '__main__':
print(nms())
使用python编写nms算法就非常简便了,就不需要使用struct等等相对繁琐的操作了。而中最关键的两个函数就是argsort函数和where函数了。argsort函数返回经过排序之后,有序数组各个位置上是原数组的哪个元素的索引,这就意味着我们不需要对bbox进行排序,只需要对置信度进行排序(对比一下c++中sort函数,它就需要对原来数组进行改变,当然在c++中你也可以使用与这相似的思想进行排序);where函数就是返回数组中符合某个条件的元素的索引,where返回的是一个元素,我们只需要第0个元素就行了。