前言
有一个同学在群里抱怨,自己不会KD树,并向我发出了直击灵魂的提问:KD树不比谈恋爱难!
KD树真的有这么难吗?工作中也经常用到kd树,也早已熟记于心,接下来我将拿出被窝里讲题的12分耐心,娓娓道来。
一、什么是KD树?
假设你非常牛逼,你有无数个优质基因,而且你们加祖祖辈辈都是两代单传,你每次都是把一半给大儿子,一半给小儿子;你的儿子同样把一半给大儿子,一半给二儿子.....从此一直传下去。提问:你的某个基因在你的哪个后代里面?想要知道答案是不是就顺这个每次给的分界去查找,最后就找到了!
KD树就是一个理想的解题模板,套用就可以理解KD是什么!
我理解的KD树:
(1)KD树是一个二叉树,一个节点包含两个子节点;
(2)KD树的一个节点的切分轴是外包围盒的最长轴和根据一个节点内包含的点集的中位数得到的;
一个KD树的构建过程:
(1)获取节点包围盒及其最长轴
(2)节点排序,并以中位数取切分轴,切分
(3)重复(1)(2)得到整个KD树
二、KD树的重要知识点
上述我们了解了KD树如何构建,接下来具体展开:
1、KD树的作用
KD树主要用于最近邻搜索。在搜索时,从树的根节点开始,沿着树向下移动,根据当前节点的分割维度和值来决定向左子树或右子树移动。在达到叶子节点后,回溯并检查其他子树,直到找到最近的邻居。
KD树在搜索最近邻点时非常有效,因为它通过剪枝操作减少了搜索空间。通过仅在最近邻点的候选集合中搜索,KD树可以减少计算开销。
2、KD树的特征
平衡,构建过程中是中位数保证了树的平衡性,搜索性能极佳!
3、KD树的难点!
插入和删除操作。KD树的插入和删除操作相对较复杂,因为需要保持树的平衡性。通常,插入和删除时需要重新平衡树以维持其性质。最开始上手时,先以查询为主,插入删除后续再出详细文章,目前你先了解为什么难,因为删除点要保证平衡,为什么不必要,我不删除插入就不会涉及树的调整。
三、三维KD树的代码及开源库
1、KD树的开源库
现在有很多开源库有这些代码,而且也不用编译,拿来就可以用:比如FLANN(Fast Library for Approximate Nearest Neighbors),PCL(Point Cloud Library)、CGAL(Computational Geometry Algorithms Library)等,感兴趣的同学可以去下载查看。
2、C++代码
为了简单了解和方便查看,我还是贴一些代码,大家看个乐子。
#include <iostream>
#include <vector>
#include <algorithm>
//点的基本结构
struct Point {
double x, y, z;
Point(double _x, double _y, double _z) : x(_x), y(_y), z(_z) {}
double operator [](int i) const
{
if (i == 0)
return x;
if (i == 1)
return y;
if (i == 2)
return z;
}
};
//点集的包围盒
struct BoundingBox {
Point minPoint;
Point maxPoint;
BoundingBox(const Point& minP, const Point& maxP) : minPoint(minP), maxPoint(maxP) {}
};
BoundingBox calculateBoundingBox(const std::vector<Point>& points) {
if (points.empty()) {
// 返回一个无效的包围盒,表示空点集
return BoundingBox(Point(0, 0, 0), Point(0, 0, 0));
}
double minX = std::numeric_limits<double>::infinity();
double minY = std::numeric_limits<double>::infinity();
double minZ = std::numeric_limits<double>::infinity();
double maxX = -std::numeric_limits<double>::infinity();
double maxY = -std::numeric_limits<double>::infinity();
double maxZ = -std::numeric_limits<double>::infinity();
for (const Point& point : points) {
minX = std::min(minX, point.x);
minY = std::min(minY, point.y);
minZ = std::min(minZ, point.z);
maxX = std::max(maxX, point.x);
maxY = std::max(maxY, point.y);
maxZ = std::max(maxZ, point.z);
}
return BoundingBox(Point(minX, minY, minZ), Point(maxX, maxY, maxZ));
}
//节点结构
struct Node {
Point point;//分割节点
Node* left;//右子树
Node* right;//左子树
Node(Point p) : point(p), left(nullptr), right(nullptr) {}
};
class KDTree {
private:
Node* root;
//输入点集和深度
Node* buildTree(const std::vector<Point>& points, int depth)
{
if (points.empty())
{
return nullptr;
}
int k = 3; // 3维空间
int axis;//切分轴 0 x轴,1 y轴,2 z轴
//计算包围盒
BoundingBox&& box = calculateBoundingBox(points);
double values[3] = {
box.maxPoint[0] - box.minPoint[0],
box.maxPoint[1] - box.minPoint[1],
box.maxPoint[2] - box.minPoint[2],
};
//取最长轴作为切分轴
axis = 0;
if (values[1] > values[axis])
axis = 1;
if (values[2] > values[axis])
axis = 2;
// 中位数排序以选择分割点
if (axis == 0)
{
//从小到大排序
std::sort(points.begin(), points.end(), [](const Point& a, const Point& b)->bool {return a.x < b.x;});
}
else if (axis == 1)
{
std::sort(points.begin(), points.end(), [](const Point& a, const Point& b)->bool {return a.y < b.y;});
}
else
{
std::sort(points.begin(), points.end(), [](const Point& a, const Point& b) ->bool {return a.z < b.z;});
}
//取中点
int medianIndex = points.size() / 2;
Node* node = new Node(points[medianIndex]);
//左 [0,medianIndex)
node->left = buildTree(std::vector<Point>(points.begin(), points.begin() + medianIndex), depth + 1);
//右 [medianIndex,points.size()]
node->right = buildTree(std::vector<Point>(points.begin() + medianIndex + 1, points.end()), depth + 1);
return node;
}
void nearestNeighborSearch(Node* node, const Point& target, Node*& best, double& bestDistance, int depth)
{
if (node == nullptr)
{
return;
}
int k = 3; // 3维空间
int axis = depth % k;
Node* nextBranch = nullptr;
Node* otherBranch = nullptr;
if (axis == 0)
{
nextBranch = (target.x < node->point.x) ? node->left : node->right;
otherBranch = (target.x < node->point.x) ? node->right : node->left;
}
else if (axis == 1)
{
nextBranch = (target.y < node->point.y) ? node->left : node->right;
otherBranch = (target.y < node->point.y) ? node->right : node->left;
}
else
{
nextBranch = (target.z < node->point.z) ? node->left : node->right;
otherBranch = (target.z < node->point.z) ? node->right : node->left;
}
nearestNeighborSearch(nextBranch, target, best, bestDistance, depth + 1);
double distanceToNode = calculateDistance(target, node->point);
if (distanceToNode < bestDistance)
{
best = node;
bestDistance = distanceToNode;
}
// 检查另一侧是否可能存在更近的点
if (std::abs(target[axis] - node->point[axis]) < bestDistance)
{
nearestNeighborSearch(otherBranch, target, best, bestDistance, depth + 1);
}
}
double calculateDistance(const Point& a, const Point& b)
{
return std::sqrt((a.x - b.x) * (a.x - b.x) + (a.y - b.y) * (a.y - b.y) + (a.z - b.z) * (a.z - b.z));
}
public:
KDTree(std::vector<Point>& points)
{
root = buildTree(points, 0);
}
Node* findNearestNeighbor(const Point& target)
{
Node* best = nullptr;
double bestDistance = std::numeric_limits<double>::max();
nearestNeighborSearch(root, target, best, bestDistance, 0);
return best;
}
};
int main() {
std::vector<Point> points = { Point(2, 3, 1), Point(5, 4, 8), Point(9, 6, 7), Point(4, 7, 2) };
KDTree kdTree(points);
Point target(7, 5, 6);
Node* nearestNeighbor = kdTree.findNearestNeighbor(target);
std::cout << "Nearest neighbor: (" << nearestNeighbor->point.x << ", "
<< nearestNeighbor->point.y << ", " << nearestNeighbor->point.z << ")" << std::endl;
return 0;
}