C++ 版 K-D 树(kdtree)的实现——源码解析

C++ 版 K-D 树的实现——源码解析

point 类

point 类存储了单个的点,仅有一个成员变量 cooeds_,类型为 std::array,模板参数 coordinate_type 为点的类型,可以为 intdouble 等,dimensions 为点的维度,若一个点包含三个方向(X、Y、Z)的坐标,那么维度就为 3。

point 类包含两个成员函数,get 函数获取该点某一维度上的数值,distance 函数计算两个点直接的距离,除此之外还有两个构造函数。

point 类源码:

template<typename coordinate_type, size_t dimensions>
class point {
public:
    point(std::array<coordinate_type, dimensions> c) : coords_(c) {}
    point(std::initializer_list<coordinate_type> list) {
        size_t n = std::min(dimensions, list.size());
        std::copy_n(list.begin(), n, coords_.begin());
    }

    coordinate_type get(size_t index) const {
        return coords_[index];
    }

    double distance(const point& pt) const {
        double dist = 0;
        for (size_t i = 0; i < dimensions; ++i) {
            double d = get(i) - pt.get(i);
            dist += d * d;
        }
        return dist;
    }
private:
    std::array<coordinate_type, dimensions> coords_;
};

point 类重载 << 操作符

template<typename coordinate_type, size_t dimensions>
std::ostream& operator<<(std::ostream& out, const point<coordinate_type, dimensions>& pt) {
    out << '(';
    for (size_t i = 0; i < dimensions; ++i) {
        if (i > 0)
            out << ", ";
        out << pt.get(i);
    }
    out << ')';
    return out;
}

KD-tree 类

kdtree源码

template<typename coordinate_type, size_t dimensions>
class kdtree {
public:
    typedef point<coordinate_type, dimensions> point_type;
private:
    struct node {
        node(const point_type& pt) : point_(pt), left_(nullptr), right_(nullptr) {}
        coordinate_type get(size_t index) const {
            return point_.get(index);
        }
        double distance(const point_type& pt) const {
            return point_.distance(pt);
        }
        point_type point_;
        node* left_;
        node* right_;
    };
    node* root_ = nullptr;
    node* best_ = nullptr;
    double best_dist_ = 0;
    size_t visited_ = 0;
    std::vector<node> nodes_;
 
    struct node_cmp {
        node_cmp(size_t index) : index_(index) {}
        bool operator()(const node& n1, const node& n2) const {
            return n1.point_.get(index_) < n2.point_.get(index_);
        }
        size_t index_;
    };
 
    node* make_tree(size_t begin, size_t end, size_t index) {
        if (end <= begin)
            return nullptr;
        size_t n = begin + (end - begin)/2;
        auto i = nodes_.begin();
        std::nth_element(i + begin, i + n, i + end, node_cmp(index));
        index = (index + 1) % dimensions;
        nodes_[n].left_ = make_tree(begin, n, index);
        nodes_[n].right_ = make_tree(n + 1, end, index);
        return &nodes_[n];
    }
 
    void nearest(node* root, const point_type& point, size_t index) {
        if (root == nullptr)
            return;
        ++visited_;
        double d = root->distance(point);
        if (best_ == nullptr || d < best_dist_) {
            best_dist_ = d;
            best_ = root;
        }
        if (best_dist_ == 0)
            return;
        double dx = root->get(index) - point.get(index);
        index = (index + 1) % dimensions;
        nearest(dx > 0 ? root->left_ : root->right_, point, index);
        if (dx * dx >= best_dist_)
            return;
        nearest(dx > 0 ? root->right_ : root->left_, point, index);
    }
public:
    kdtree(const kdtree&) = delete;
    kdtree& operator=(const kdtree&) = delete;

    template<typename iterator>
    kdtree(iterator begin, iterator end) : nodes_(begin, end) {
        root_ = make_tree(0, nodes_.size(), 0);
    }
 
    template<typename func>
    kdtree(func&& f, size_t n) {
        nodes_.reserve(n);
        for (size_t i = 0; i < n; ++i)
            nodes_.push_back(f());
        root_ = make_tree(0, nodes_.size(), 0);
    }

    bool empty() const { return nodes_.empty(); }
 
    size_t visited() const { return visited_; }

    double distance() const { return std::sqrt(best_dist_); }
 
    const point_type& nearest(const point_type& pt) {
        if (root_ == nullptr)
            throw std::logic_error("tree is empty");
        best_ = nullptr;
        visited_ = 0;
        best_dist_ = 0;
        nearest(root_, pt, 0);
        return best_->point_;
    }
};

node 结构体

在 kdtree 内部创建了 node 的结构体,该结构体内部包含了一个点,以及指向左右节点的两个指针,形成了二叉树的结构。

同时,node_cmp 用于比较两个点在某一维度上的大小。

kdtree 成员变量

  • node* root_ = nullptr; :根节点指针
  • node* best_ = nullptr; :查找到的节点
  • double best_dist_ = 0; :查找到的节点与目标节点之间的距离
  • size_t visited_ = 0; :查找目标点时的访问次数
  • std::vector<node> nodes_; :存储 nodevector 容器

kdtree 两个重要的成员函数

构建 kdtree

构建树的代码采用递归思想:

  1. 获取中间位置的元素;
  2. 对中间位置元素进行排序,只让中间元素在正确的排序位置,其余元素位置不管;
  3. 切换维度,树的下一层通过另一个维度排序;
  4. 让中间元素作为该递归层的根节点,根节点连接左、右节点;
  5. 递归的构建左右两侧的子树;
  6. 当节点为空时返回;
  7. 最终返回根节点。

构建树的函数是在构造函数内完成的,即在输入点云完成后就构建了 kdtree。

查找最近邻

查找函数 nearest 也是采用的递归:

  1. 首先计算输入节点(参数一)与目标节点(参数二)之间的距离,若输入节点为空或距离小于 best_dist_,更新 best_dist_best_
  2. 若存在 best_dist_ 为 0,函数直接返回,表示找到了最近的点;
  3. 计算两节点在输入维度上的距离;
  4. 切换维度;
  5. 判断某一维度上的距离差值大小,小于 0 再去左子树中查找,反之去右子树中查找,若计算得到某一维度的值直接大于 best_dist_,直接返回,因为再往下面去查找没有意义了,因为在这一维度上,它们之间的距离越来越大。
  6. 节点为空时,返回返回。

在类中 public 作用域下,重载了 nearest 函数,该函数返回查找到的节点指针指向的最佳点。

kdtree 其他成员函数

  • bool empty() const { return nodes_.empty(); } 点是否为空
  • size_t visited() const { return visited_; } 查找最近点所用的访问次数
  • double distance() const { return std::sqrt(best_dist_); } 返回与查找点的距离

测试 KD-tree

void test_wikipedia() {
    typedef point<int, 2> point2d;
    typedef kdtree<int, 2> tree2d;
 
    point2d points[] = { { 2, 3 }, { 5, 4 }, { 9, 6 }, { 4, 7 }, { 8, 1 }, { 7, 2 } };
 
    tree2d tree(std::begin(points), std::end(points));
    point2d n = tree.nearest({ 9, 2 });
 
    std::cout << "Wikipedia example data:\n";
    std::cout << "nearest point: " << n << '\n';
    std::cout << "distance: " << tree.distance() << '\n';
    std::cout << "nodes visited: " << tree.visited() << '\n';
}
 
typedef point<double, 3> point3d;
typedef kdtree<double, 3> tree3d;
 
struct random_point_generator {
    random_point_generator(double min, double max)
        : engine_(std::random_device()()), distribution_(min, max) {}
 
    point3d operator()() {
        double x = distribution_(engine_);
        double y = distribution_(engine_);
        double z = distribution_(engine_);
        return point3d({x, y, z});
    }
 
    std::mt19937 engine_;
    std::uniform_real_distribution<double> distribution_;
};
 
void test_random(size_t count) {
    random_point_generator rpg(0, 1);
    tree3d tree(rpg, count);
    point3d pt(rpg());
    point3d n = tree.nearest(pt);
 
    std::cout << "Random data (" << count << " points):\n";
    std::cout << "point: " << pt << '\n';
    std::cout << "nearest point: " << n << '\n';
    std::cout << "distance: " << tree.distance() << '\n';
    std::cout << "nodes visited: " << tree.visited() << '\n';
}
 
int main() 
{
    try 
    {
        test_wikipedia();
        std::cout << '\n';
        test_random(1000);
        std::cout << '\n';
        test_random(1000000);
    } 
    catch (const std::exception& e) 
    {
        std::cerr << e.what() << '\n';
    }
    return 0;
}

结果输出

Wikipedia example data:
nearest point: (8, 1)
distance: 1.41421
nodes visited: 3

Random data (1000 points):
point: (0.740311, 0.290258, 0.832057)
nearest point: (0.761247, 0.294663, 0.83404)
distance: 0.0214867
nodes visited: 15

Random data (1000000 points):
point: (0.646712, 0.555327, 0.596551)
nearest point: (0.642795, 0.552513, 0.599618)
distance: 0.00571496
nodes visited: 46

包含的头文件

上述代码包含的头文件如下,将这些代码整理到一起就可以测试了。

#include <algorithm>
#include <array>
#include <cmath>
#include <iostream>
#include <random>
#include <vector>
  • 9
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值