改之前:不区分类别,所有检测框一起做 NMS:
// Class-agnostic NMS: boxes are suppressed against each other regardless of class.
cv::dnn::NMSBoxes(boxes, confidences, conf_threshold, nms_threshold, indices);
改之后:只在同一类别内做 NMS,不同类别之间不做 NMS:
/// Clips a box against the image bounds and returns the clipped box.
///
/// The box corners are narrowed to float first. The top-left corner is clamped
/// to be no less than 0 and the bottom-right corner to be no greater than
/// (img_width - 1, img_height - 1); width and height are then recomputed from
/// the clamped corners.
///
/// @param box        Box to clip, given as x, y, width, height.
/// @param img_width  Image width; the right edge is clamped to img_width - 1.
/// @param img_height Image height; the bottom edge is clamped to img_height - 1.
/// @return The clipped box. No check is made that the result is non-empty, so
///         width/height can come out negative for a box outside the bounds.
inline cv::Rect2d RemapBoxOnSrc(const cv::Rect2d &box, const int img_width, const int img_height)
{
    // Corner representation of the incoming box, in single precision.
    const float left = static_cast<float>(box.x);
    const float top = static_cast<float>(box.y);
    const float right = left + static_cast<float>(box.width);
    const float bottom = top + static_cast<float>(box.height);

    // Clamp each corner coordinate to the permitted range.
    const float clipped_left = std::max(.0f, left);
    const float clipped_top = std::max(.0f, top);
    const float clipped_right = std::min(img_width - 1.0f, right);
    const float clipped_bottom = std::min(img_height - 1.0f, bottom);

    cv::Rect2d clipped;
    clipped.x = clipped_left;
    clipped.y = clipped_top;
    // Subtract against the stored (double) origin, as the extents are doubles.
    clipped.width = clipped_right - clipped.x;
    clipped.height = clipped_bottom - clipped.y;
    return clipped;
}
/// Returns the area of the overlap between two boxes, or 0 when they do not
/// overlap.
///
/// Each box is treated as the rectangle spanning [left, left + width] by
/// [top, top + height]. Boxes that merely touch along an edge or at a corner
/// count as non-overlapping.
inline float intersectionArea(const bbox& box1, const bbox& box2)
{
    const float overlap_left = std::max(box1.left, box2.left);
    const float overlap_right = std::min(box1.left + box1.width, box2.left + box2.width);
    const float overlap_top = std::max(box1.top, box2.top);
    const float overlap_bottom = std::min(box1.top + box1.height, box2.top + box2.height);

    // An empty extent on either axis means there is no intersection.
    if (!(overlap_left < overlap_right) || !(overlap_top < overlap_bottom)) {
        return 0.0f;
    }
    return (overlap_right - overlap_left) * (overlap_bottom - overlap_top);
}
/// Returns the area of a box, computed as width * height.
inline float boxArea(const bbox& box)
{
    const float area = box.width * box.height;
    return area;
}
inline void NMSBoxes(std::vector<bbox> input_boxes, const int conf_threshold, const float nms_threshold, std::vector<bbox> & nms_boxes)
{
std::vector<bbox> sortedBoxes = input_boxes;
printf("confidence=%f, cls=%d\n", sortedBoxes[0].confidence, sortedBoxes[0].cls_id);
std::sort(sortedBoxes.begin(), sortedBoxes.end(), [](bbox a, bbox b) {return a.confidence > b.confidence; });
printf("confidence=%f, cls=%d\n", sortedBoxes[0].confidence, sortedBoxes[0].cls_id);
while(!sortedBoxes.empty()){
// 每次取分数最高的合法锚框
printf("ccccccccccccccc\n");
const bbox currentBox = sortedBoxes.front();
nms_boxes.push_back(currentBox);
sortedBoxes.erase(sortedBoxes.begin()); // 取完后从候选锚框中删除
// 计算剩余锚框与合法锚框的IOU
std::vector<bbox>::iterator it = sortedBoxes.begin();
printf("sortedBoxes=%d, conf=%f\n", sortedBoxes.size(), it[0].confidence);
while (it != sortedBoxes.end()){
const bbox candidateBox = *it; // 取当前候选锚框
if (currentBox.cls_id==candidateBox.cls_id)
{
float intersection = intersectionArea(currentBox, candidateBox); // 计算候选框和合法框的交集面积
float iou = intersection / (boxArea(currentBox) + boxArea(candidateBox) - intersection); // 计算iou
printf("iou=%f\n", iou);
if (iou >= nms_threshold)
{sortedBoxes.erase(it);
printf("delet\n");} // 根据阈值过滤锚框,过滤完it指向下一个锚框
else it++; // 保留当前锚框,判断下一个锚框
}
else
{
it++; // 保留当前锚框,判断下一个锚框
}
}
printf("input_boxes=%d, nms_boxes=%d, sortedBoxes=%d\n", input_boxes.size(), nms_boxes.size(), sortedBoxes.size());
}
return ;
};