KD-Tree 的开源实现与OpenCV中的实现

KD-Tree 开源实现代码
  • 添加头文件 #include <KDTree.hpp>
  • 测试代码及源文件 下载链接(OpenCV version: 4.5.1)
  • hpp 文件名为 “KDTree.hpp”
#ifndef __KDTREE_H__
#define __KDTREE_H__

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <exception>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

namespace obstarcalib::caliblink {
/** @brief k-d tree over a fixed set of points.
 *
 *  Answers nearest-neighbour, k-nearest-neighbour and radius queries using the
 *  Euclidean distance over the first DIM coordinates of each point.
 *
 *  @tparam PointT    Point type. Coordinate i is read as `point[i]`; coordinates
 *                    must be comparable with `<` and convertible to `double`.
 *  @tparam DIM       Number of coordinates per point (must be positive).
 *  @tparam Allocator Allocator of the vector that stores the points.
 *
 *  The tree keeps its own copy of the points it was built from. Every index
 *  returned by a query is a position in that vector.
 */
template <typename PointT, int DIM = 2, typename Allocator = std::allocator<PointT>>
class KDTree {
    // buildRecursive() computes `depth % DIM`, so a non-positive DIM is meaningless.
    static_assert(DIM > 0, "KDTree: DIM must be positive");

public:
    /** @brief Creates an empty tree.
     */
    KDTree() = default;

    /** @brief Creates a tree over a copy of @p points.
     */
    KDTree(const std::vector<PointT, Allocator>& points)
        : points_(points)
    {
        rebuild();
    }

    /** @brief Deep copy: copies the points and builds a fresh set of nodes.
     */
    KDTree(const KDTree& other)
        : points_(other.points_)
    {
        rebuild();
    }

    /** @brief Takes over the nodes and points of @p other.
     */
    KDTree(KDTree&& other) noexcept = default;

    /** @brief Deep-copy assignment (self-assignment is a no-op).
     */
    KDTree& operator=(const KDTree& other)
    {
        if (this != &other)
            build(other.points_);
        return *this;
    }

    /** @brief Move assignment.
     */
    KDTree& operator=(KDTree&& other) = default;

    /** @brief The destructor. Nodes are owned by std::unique_ptr and released automatically.
     */
    ~KDTree() = default;

    /** @brief Re-builds k-d tree from a copy of @p points, discarding the previous content.
     */
    void build(const std::vector<PointT, Allocator>& points)
    {
        // Drop the old nodes first so no node ever refers to a stale index.
        root_.reset();
        points_ = points;
        rebuild();
    }

    /** @brief Clears k-d tree (removes all nodes and all stored points).
     */
    void clear()
    {
        root_.reset();
        points_.clear();
    }

    /** @brief Validates k-d tree.
     *
     *  @return true when every node is not smaller than its left child and not
     *          larger than its right child along the node's split axis.
     */
    bool validate() const
    {
        return validateRecursive(root_.get());
    }

    /** @brief Searches the nearest neighbor.
     *
     *  @param query   Query point.
     *  @param minDist Optional output: Euclidean distance to the returned point.
     *                 For an empty tree it receives `numeric_limits<double>::max()`.
     *  @return Index of the nearest stored point, or -1 when the tree is empty.
     */
    int nnSearch(const PointT& query, double* minDist = nullptr) const
    {
        int guess = -1; // stays -1 when there is nothing to search
        double best = std::numeric_limits<double>::max();

        nnSearchRecursive(query, root_.get(), &guess, &best);

        if (minDist)
            *minDist = best;

        return guess;
    }

    /** @brief Searches k-nearest neighbors.
     *
     *  @param query Query point.
     *  @param k     Maximum number of neighbours to return.
     *  @return Indices of up to @p k nearest points, closest first. Empty when
     *          @p k is not positive or the tree is empty.
     */
    std::vector<int> knnSearch(const PointT& query, int k) const
    {
        if (k <= 0 || !root_)
            return {};

        // More than points_.size() results are impossible; clamping keeps the
        // queue's up-front reservation proportional to the real data size.
        const std::size_t bound = std::min(static_cast<std::size_t>(k), points_.size());

        KnnQueue queue(bound);
        knnSearchRecursive(query, root_.get(), queue);

        std::vector<int> indices;
        indices.reserve(queue.size());
        for (std::size_t i = 0; i < queue.size(); ++i)
            indices.push_back(queue[i].second);

        return indices;
    }

    /** @brief Searches neighbors within radius.
     *
     *  @param query  Query point.
     *  @param radius Search radius; a point is reported when its distance is
     *                strictly smaller than @p radius.
     *  @return Indices of the matching points in tree-traversal order (not sorted by distance).
     */
    std::vector<int> radiusSearch(const PointT& query, double radius) const
    {
        std::vector<int> indices;
        radiusSearchRecursive(query, root_.get(), indices, radius);
        return indices;
    }

private:
    /** @brief k-d tree node.
     */
    struct Node {
        int idx = -1; //!< index to the original point
        int axis = -1; //!< dimension's axis
        std::unique_ptr<Node> next[2]; //!< owning pointers to the child nodes
    };

    /** @brief Bounded priority queue: keeps at most `bound` smallest elements in ascending order.
     */
    template <class T, class Compare = std::less<T>>
    class BoundedPriorityQueue {
    public:
        BoundedPriorityQueue() = delete;
        explicit BoundedPriorityQueue(std::size_t bound)
            : bound_(bound)
        {
            // One spare slot: push() inserts first and trims afterwards.
            if (bound_ < elements_.max_size())
                elements_.reserve(bound_ + 1);
        }

        /** @brief Inserts @p val, dropping the largest element when the bound is exceeded.
         */
        void push(const T& val)
        {
            if (bound_ == 0)
                return;

            const Compare less {};

            // A value that is not better than the current worst would be
            // inserted at the end and trimmed right away; skip the work.
            if (full() && !less(val, elements_.back()))
                return;

            const auto pos = std::upper_bound(elements_.begin(), elements_.end(), val, less);
            elements_.insert(pos, val);

            if (elements_.size() > bound_)
                elements_.pop_back();
        }

        bool full() const { return elements_.size() >= bound_; }
        const T& back() const { return elements_.back(); }
        const T& operator[](std::size_t index) const { return elements_[index]; }
        std::size_t size() const { return elements_.size(); }

    private:
        std::size_t bound_;
        std::vector<T> elements_;
    };

    /** @brief Priority queue of <distance, index> pair.
     */
    using KnnQueue = BoundedPriorityQueue<std::pair<double, int>>;

    /** @brief Builds the nodes for the points currently stored in points_.
     */
    void rebuild()
    {
        std::vector<int> indices(points_.size());
        std::iota(indices.begin(), indices.end(), 0);

        root_ = buildRecursive(indices.data(), static_cast<int>(indices.size()), 0);
    }

    /** @brief Builds k-d tree recursively.
     *
     *  @param indices Point indices of this subtree (reordered in place).
     *  @param npoints Number of entries in @p indices.
     *  @param depth   Depth of the node being built; selects the split axis.
     */
    std::unique_ptr<Node> buildRecursive(int* indices, int npoints, int depth) const
    {
        if (npoints <= 0)
            return nullptr;

        const int axis = depth % DIM;
        const int mid = (npoints - 1) / 2;

        // Partition around the median along the split axis.
        std::nth_element(indices, indices + mid, indices + npoints, [this, axis](int lhs, int rhs) {
            return points_[lhs][axis] < points_[rhs][axis];
        });

        auto node = std::make_unique<Node>();
        node->idx = indices[mid];
        node->axis = axis;

        node->next[0] = buildRecursive(indices, mid, depth + 1);
        node->next[1] = buildRecursive(indices + mid + 1, npoints - mid - 1, depth + 1);

        return node;
    }

    /** @brief Validates k-d tree recursively.
     *
     *  Each child is checked on its own, so nodes with a single child are
     *  validated as well.
     */
    bool validateRecursive(const Node* node) const
    {
        if (node == nullptr)
            return true;

        const int axis = node->axis;
        const Node* lower = node->next[0].get();
        const Node* upper = node->next[1].get();

        if (lower && points_[node->idx][axis] < points_[lower->idx][axis])
            return false;

        if (upper && points_[upper->idx][axis] < points_[node->idx][axis])
            return false;

        return validateRecursive(lower) && validateRecursive(upper);
    }

    /** @brief Signed difference of two points along one axis.
     *
     *  Both coordinates are converted to double before subtracting so that
     *  unsigned or narrow integer coordinates cannot wrap or overflow.
     */
    static double axisDiff(const PointT& p, const PointT& q, int axis)
    {
        return static_cast<double>(p[axis]) - static_cast<double>(q[axis]);
    }

    /** @brief Euclidean distance between two points over the first DIM coordinates.
     */
    static double distance(const PointT& p, const PointT& q)
    {
        double sum = 0.0;
        for (int i = 0; i < DIM; ++i) {
            const double d = axisDiff(p, q, i);
            sum += d * d;
        }
        return std::sqrt(sum);
    }

    /** @brief Searches the nearest neighbor recursively.
     */
    void nnSearchRecursive(const PointT& query, const Node* node, int* guess, double* minDist) const
    {
        if (node == nullptr)
            return;

        const PointT& train = points_[node->idx];

        const double dist = distance(query, train);
        if (dist < *minDist) {
            *minDist = dist;
            *guess = node->idx;
        }

        // Descend into the side of the splitting plane that contains the query first.
        const int axis = node->axis;
        const int dir = query[axis] < train[axis] ? 0 : 1;
        nnSearchRecursive(query, node->next[dir].get(), guess, minDist);

        // The other side can only help if the plane is closer than the best match.
        const double diff = std::fabs(axisDiff(query, train, axis));
        if (diff < *minDist)
            nnSearchRecursive(query, node->next[1 - dir].get(), guess, minDist);
    }

    /** @brief Searches k-nearest neighbors recursively.
     */
    void knnSearchRecursive(const PointT& query, const Node* node, KnnQueue& queue) const
    {
        if (node == nullptr)
            return;

        const PointT& train = points_[node->idx];

        queue.push(std::make_pair(distance(query, train), node->idx));

        const int axis = node->axis;
        const int dir = query[axis] < train[axis] ? 0 : 1;
        knnSearchRecursive(query, node->next[dir].get(), queue);

        // Visit the far side while the queue still has room, or when the plane
        // is closer than the current worst candidate.
        const double diff = std::fabs(axisDiff(query, train, axis));
        if (!queue.full() || diff < queue.back().first)
            knnSearchRecursive(query, node->next[1 - dir].get(), queue);
    }

    /** @brief Searches neighbors within radius.
     */
    void radiusSearchRecursive(const PointT& query, const Node* node, std::vector<int>& indices, double radius) const
    {
        if (node == nullptr)
            return;

        const PointT& train = points_[node->idx];

        if (distance(query, train) < radius)
            indices.push_back(node->idx);

        const int axis = node->axis;
        const int dir = query[axis] < train[axis] ? 0 : 1;
        radiusSearchRecursive(query, node->next[dir].get(), indices, radius);

        const double diff = std::fabs(axisDiff(query, train, axis));
        if (diff < radius)
            radiusSearchRecursive(query, node->next[1 - dir].get(), indices, radius);
    }

    std::unique_ptr<Node> root_; //!< root node
    std::vector<PointT, Allocator> points_; //!< points
};
} // namespace obstarcalib::caliblink

#endif // !__KDTREE_H__

  • 测试代码
#include <iostream>
#include <KDTree.hpp>
#include <opencv2/opencv.hpp>

using namespace std;
using namespace obstarcalib::caliblink;
using namespace cv;

void testMyKDTreeNN() // Nearest-neighbour query against the header-only KDTree.
{
    // Sample set: (0,1), (1,2), (2,3), (3,4), (4,5), (5,6).
    float query[6][2];
    for (int row = 0; row < 6; row++)
    {
        query[row][0] = row;
        query[row][1] = row + 1;
    }

    // Iterator-range construction: each float[2] row decays to a pointer to
    // its first element, so the vector holds one row pointer per sample.
    vector<float *> v1(query, query + 6);
    KDTree<float *> kdTree(v1);

    // Point to look up.
    float train[2] = {3.6, 7.1};

    // Receives the Euclidean distance (already square-rooted) to the match.
    double nearestDist = 0;
    const int index = kdTree.nnSearch(train, &nearestDist); // index of the nearest sample
    cout << index << "---" << nearestDist << endl;
}

void testMyKDTreeKNN() // Searches k-nearest neighbors.
{
    // float query[6][2] = {{0, 1}, {1, 2}, {2, 3}, {3, 4}, {4, 5}, {5, 6}};
    float query[6][2];
    for (int i = 0; i < 6; i++)
    {
        query[i][0] = i; query[i][1] = i + 1;
    } 
    vector<float *> v1 = {query, query + 6}; // 存放query每行元素的首地址
    KDTree<float *> kdTree(v1);

    // KDTree 搜索
    float train[2] = {3.6, 7.1}; // 寻找的对象
    int k = 2;
    vector<int> index = kdTree.knnSearch(train, k); // 返回k个最近邻索引
    cout << index[0] << endl;
}

void testOpenCVKDTreeKNN() // Searches k-nearest neighbors by OpenCV
{
    // Point set used to build the k-d tree: (0,1), (1,2), ..., (5,6).
    // The ints are cast explicitly: brace-initialising a Point2f from
    // non-constant ints is a narrowing conversion and is ill-formed.
    vector<Point2f> features;
    for (int i = 0; i < 6; i++)
        features.emplace_back(static_cast<float>(i), static_cast<float>(i + 1));

    // View the vector<Point2f> as a single-channel matrix (one row per point).
    Mat source = Mat(features).reshape(1);
    source.convertTo(source, CV_32F); // make sure the element type is 32-bit float

    cv::flann::KDTreeIndexParams indexParams(1); // number of k-d trees to build
    cv::flann::Index kdtree(source, indexParams);

    // Arguments and output containers for knnSearch.
    int queryNum = 2;                          // number of neighbours to return
    vector<float> vecQuery = {3.6f, 7.1f};     // query point
    vector<int> vecIndex(queryNum);            // receives the neighbour indices
    vector<float> vecDist(queryNum);           // receives the neighbour distances
    cv::flann::SearchParams params(32);        // knnSearch search parameters

    // k-d tree search; the reported distances are squared, hence the sqrt below.
    kdtree.knnSearch(vecQuery, vecIndex, vecDist, queryNum, params);
    cout << vecIndex[0] << "---" << std::sqrt(vecDist[0]) << endl;
}

int main(int argc, char **argv)
{
    // Run the three demos in order: custom NN, custom KNN, OpenCV FLANN KNN.
    void (*const demos[])() = {testMyKDTreeNN, testMyKDTreeKNN, testOpenCVKDTreeKNN};
    for (auto demo : demos)
        demo();

    // Runs the shell command "pause" (a Windows console built-in) so the
    // output stays visible until a key is pressed.
    system("pause");
    return 0;
}
总结
  • 使用开源 KD-Tree 实现时,点的数据类型必须支持用 [ ] 下标访问各维坐标;此外 knnSearch() 接口只返回匹配点的索引,不返回匹配点的距离,如有需要可自行修改源码
  • OpenCV KD-Tree 的 KNN(K-Nearest Neighbors) 匹配速度优于开源 KD-Tree;但如果只需要查找 NN (Nearest Neighbor),使用开源 KD-Tree 更快
  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值