非极大值抑制(NMS)仿真代码:来自论文Efficient Non-Maximum Suppression

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/wfh2015/article/details/79925188

非极大值抑制(Efficient Non-Maximum Suppression)

这里说的非极大值抑制与平时说的非极大值抑制有点不太一样, 平时说的NMS指的是在检测算法中的应用,将多个矩形框进行筛选,筛选出所需要的矩形框。而本文所说的NMS主要是来自一篇论文Neubeck A, Gool L V. Efficient Non-Maximum Suppression[C]// International Conference on Pattern Recognition. IEEE Computer Society, 2006:850-855.。如果代码仿真有问题,还请大家原谅!

一维3连通域NMS(1D NMS for 3-Neighborhood)

#include <vector>
#include <algorithm>
#include <time.h>
#include "opencv2/opencv.hpp"

/** @brief 对3领域情况的数据非极大值抑制算法并获取其最大值的索引

@param data int类型数据, 可在代码中适当的修改
@param index 对获得的局部最大值索引
*/
static void nms(const std::vector<int>& data, std::vector<int>& index)
{
    index.clear();

    // 对算法不能处理的数据进行处理
    if (data.size() == 1) index.push_back(0);
    else if (data.size() == 2)
    {
        int idx = data[1] > data[0] ? 1 : 0;
        index.push_back(idx);
        return;
    }
    else if (data.size() <= 0) return;
    else {}

    /////////////////////////data.size()>=3处理//////////////////////////
    int i = 1;
    while (i <= data.size()-2)
    {
        if (data[i] > data[i + 1])
        {
            if (data[i] >= data[i - 1]) index.push_back(i);
        }
        else
        {
            ++i;
            while ((i <= data.size() - 2) && (data[i] <= data[i + 1])) ++i;
            if (i <= data.size() - 2) index.push_back(i);
        }
        i = i + 2;

    } // End Outer while
}


/** @brief 验证写的非极大值抑制算法

@param data 原始数据
@param index 获取的极大值索引
*/
static bool VerifyNms(std::vector<int>& data, std::vector<int>& index)
{
    for (int i = 0; i < index.size(); ++i)
    {
        int idx = index[i];
        if ((data[idx] >= data[idx - 1]) && (data[idx] > data[idx + 1])) continue;
        return false;
    }

    return true;
}

/** @brief 随机产生vector类型数据, 数据个数为num个

@param data 用来接收所产生的数据
@param num data数据的个数
*/
static void CreateVec(std::vector<int>& data, int num)
{
    data.clear();
    srand(unsigned(time(NULL)));
    for (int i = 0; i < num; ++i)
    {
        data.push_back(i);
    }
    std::random_shuffle(data.begin(), data.end());
}

// 测试3领域非极大值抑制
void test_Nms3Field()
{
    std::vector<int> data;
    std::vector<int> index;

    const int times = 1000000;
    for (int i = 0; i < times; ++i)
    {
        std::cout << i << ": ";
        CreateVec(data, 8);
        std::cout << "create data !";

        nms(data, index);
        std::cout << "\tIndexNum:" << index.size();

        bool flag = VerifyNms(data, index);
        std::cout << "\tverify:";
        std::cout << ((flag == true) ? " " : "错误") << std::endl;
    }

    //cv::waitKey(1);
}

一维多领域(1D NMS for (2n + 1)-Neighborhood)

#include <vector>
#include <algorithm>
#include <time.h>
#include "opencv2/opencv.hpp"

/** @brief 对3领域情况的数据非极大值抑制算法并获取其最大值的索引

@param data int类型数据, 可在代码中适当的修改
@param index 对获得的局部最大值索引
*/
static void nms(const std::vector<int>& data, std::vector<int>& index)
{
    index.clear();

    // 对算法不能处理的数据进行处理
    if (data.size() == 1) index.push_back(0);
    else if (data.size() == 2)
    {
        int idx = data[1] > data[0] ? 1 : 0;
        index.push_back(idx);
        return;
    }
    else if (data.size() <= 0) return;
    else {}

    /////////////////////////data.size()>=3处理//////////////////////////
    int i = 1;
    while (i <= data.size()-2)
    {
        if (data[i] > data[i + 1])
        {
            if (data[i] >= data[i - 1]) index.push_back(i);
        }
        else
        {
            ++i;
            while ((i <= data.size() - 2) && (data[i] <= data[i + 1])) ++i;
            if (i <= data.size() - 2) index.push_back(i);
        }
        i = i + 2;

    } // End Outer while
}


/** @brief 验证写的非极大值抑制算法

@param data 原始数据
@param index 获取的极大值索引
*/
static bool VerifyNms(std::vector<int>& data, std::vector<int>& index)
{
    for (int i = 0; i < index.size(); ++i)
    {
        int idx = index[i];
        if ((data[idx] >= data[idx - 1]) && (data[idx] > data[idx + 1])) continue;
        return false;
    }

    return true;
}

/** @brief 随机产生vector类型数据, 数据个数为num个

@param data 用来接收所产生的数据
@param num data数据的个数
*/
static void CreateVec(std::vector<int>& data, int num)
{
    data.clear();
    srand(unsigned(time(NULL)));
    for (int i = 0; i < num; ++i)
    {
        data.push_back(i);
    }
    std::random_shuffle(data.begin(), data.end());
}

// 测试3领域非极大值抑制
void test_Nms3Field()
{
    std::vector<int> data;
    std::vector<int> index;

    const int times = 1000000;
    for (int i = 0; i < times; ++i)
    {
        std::cout << i << ": ";
        CreateVec(data, 8);
        std::cout << "create data !";

        nms(data, index);
        std::cout << "\tIndexNum:" << index.size();

        bool flag = VerifyNms(data, index);
        std::cout << "\tverify:";
        std::cout << ((flag == true) ? "正确" : "错误") << std::endl;
    }

    //cv::waitKey(1);
}

二维多领域NMS(2D (n + 1) × (n + 1)-Block NMS)

#include <vector>
#include <cassert>
#include "opencv2/opencv.hpp"

/** @brief 2D非极大值抑制算法-主要是针对图像

@param bin 单通道矩阵
@param index 获取的极大值坐标的映射图
@param n 领域数目(2n+1)*(2n+1)
*/
static void nms_2D(cv::Mat& bin, cv::Mat& index, int n)
{
    CV_Assert(bin.type() == CV_8UC1);
    assert((n % 2 == 1) && (n > 0));
    assert((bin.rows >= n) && (bin.cols >= n));

    index = cv::Mat::zeros(bin.rows, bin.cols, CV_8UC1);
    n = n / 2;

    for (int y = n; y <= bin.rows - n - 1; ++y)
    {
        for (int x = n; x <= bin.cols - n - 1; ++x)
        {
            uchar v = bin.at<uchar>(y, x);
            cv::Point m = cv::Point(x, y);

            for (int y2 = y; y2 <= y + n; ++y2)
            {
                for (int x2 = x; x2 <= x + n; ++x2)
                {
                    if ((x2 < 0) || (x2 >= bin.cols) || (y2 < 0) || (y2 >= bin.rows)) continue;
                    if (bin.at<uchar>(y2, x2) > bin.at<uchar>(m.y, m.x))    m = cv::Point(x2, y2);
                } // End x2
            } //End y2

            for (int y2 = m.y - n; y2 <= m.y + n; ++y2)
            {
                for (int x2 = m.x - n; x2 <= m.x + n; ++x2)
                {
                    if ((x2 < 0) || (x2 >= bin.cols) || (y2 < 0) || (y2 >= bin.rows)) continue;
                    if (bin.at<uchar>(y2, x2) > bin.at<uchar>(m.y, m.x))
                    {
                        goto failed;    // goto failed
                    }
                } // End x2
            } // End y2
            index.at<uchar>(m.y, m.x) = 255;
        failed:
            ;
        } // End x
    } // End y
    //cv::waitKey(1);
}


void test_nms_2D()
{

    std::string path = "../Resources/wechat_20180409161327.bmp";
    cv::Mat img = cv::imread(path, cv::IMREAD_GRAYSCALE);
    if (!img.data) return;
    img = ~img;

    cv::Mat index;
    nms_2D(img, index, 5);

    cv::imshow("图像原图", img);
    cv::imshow("找出来的结果", index);
    cv::waitKey(0);
}
展开阅读全文

没有更多推荐了,返回首页