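I. Introduction to kd-tree
A kd-tree (short for k-dimensional tree) is a data structure that partitions a k-dimensional data space and stores the points in it; in computer science terms, it organizes points in k-dimensional Euclidean space. It is a special case of a binary space partitioning (BSP) tree and is commonly used for searches over multidimensional keys, such as range search and nearest-neighbor search.
A kd-tree is a binary tree in which every node is a k-dimensional point. Every non-leaf node can be viewed as a hyperplane that splits the space into two half-spaces. The left subtree of a node holds the points on the left of the hyperplane (those whose coordinate on the splitting dimension is smaller than the node's), and the right subtree holds the points on the right (those whose coordinate on the splitting dimension is larger).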
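To make the left/right convention concrete, here is a minimal sketch that partitions a point set by a single axis-aligned splitting plane. The names `Point`, `Partition`, and `partition_by_plane` are illustrative assumptions and are not part of the implementation in Section III; sending ties to the left is one possible convention, chosen here for the example.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Point = std::vector<float>;  // hypothetical alias for one k-dimensional point

struct Partition {
    std::vector<Point> left;   // points with p[axis] <= value
    std::vector<Point> right;  // points with p[axis] >  value
};

// Splits `points` by the hyperplane p[axis] == value.
// Assumption: ties (p[axis] == value) are placed on the left side.
Partition partition_by_plane(const std::vector<Point>& points,
                             std::size_t axis, float value) {
    Partition result;
    for (const Point& p : points) {
        assert(axis < p.size());  // every point must have the split axis
        if (p[axis] <= value) {
            result.left.push_back(p);
        } else {
            result.right.push_back(p);
        }
    }
    return result;
}
```

Calling `partition_by_plane(points, 0, 7.f)` on a 2-D data set splits it by the vertical line x = 7; each half can then be split again along another axis, which is what building the tree amounts to.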
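II. Implementing a kd-tree
An implementation of a kd-tree consists of two main parts:
1) building the kd-tree;
2) nearest-neighbor search.
1. Building the kd-tree
In a kd-tree, each node represents a region of space. The data structure of a node is shown below (the C++ implementation in Section III keeps only point, split, left, and right):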
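Name | Data type | Description |
---|---|---|
point | Data vector | A data point from the data set; a k-dimensional vector |
range | Space vector | The region of space represented by this node |
split | Integer | Index of the axis perpendicular to the splitting hyperplane |
left | Node | Left half-space, i.e. the left child |
right | Node | Right half-space, i.e. the right child |
parent | Node | Parent of the current node |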
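The table can be written down as a C++ struct roughly as follows. This is only a sketch under stated assumptions: the names `KdNode`, `Range`, and `attach_left` are hypothetical, `range` is assumed to be an axis-aligned box stored as per-dimension minima and maxima (the table only calls it a "space vector"), and the ownership model (children owned through `std::unique_ptr`, parent as a non-owning pointer) is a design choice of this example rather than something the article prescribes.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

using Point = std::vector<float>;  // hypothetical alias for one k-dimensional point

// Assumed representation of `range`: an axis-aligned bounding box.
struct Range {
    Point min;  // smallest coordinate covered, per dimension
    Point max;  // largest coordinate covered, per dimension
};

struct KdNode {
    Point point;                    // data point stored at this node
    Range range;                    // region of space covered by this subtree
    int split = 0;                  // axis perpendicular to the splitting hyperplane
    std::unique_ptr<KdNode> left;   // left half-space (owned)
    std::unique_ptr<KdNode> right;  // right half-space (owned)
    KdNode* parent = nullptr;       // non-owning pointer back to the parent
};

// Attaches `child` as the left child of `parent` and sets its back pointer.
inline KdNode* attach_left(KdNode& parent, std::unique_ptr<KdNode> child) {
    assert(child != nullptr);
    child->parent = &parent;
    parent.left = std::move(child);
    return parent.left.get();
}
```

With owning child pointers the whole tree is released when the root goes out of scope, which avoids manual cleanup; the raw-pointer layout used in Section III is simpler but requires freeing nodes explicitly.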
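Two guidelines apply when building a kd-tree:
1) the tree should be as balanced as possible: the more balanced the tree, the more evenly the space is divided and the shorter the search time;
2) the opportunities for pruning during neighborhood search should be maximized.
To choose the split field, we use the dimension with the largest variance. The larger the variance of the data along a dimension, the more spread out the points are along it, so splitting there makes a well-balanced partition easier to achieve. Once the split dimension is chosen, the points are sorted by their value on that dimension and the middle (median) point is selected as the pivot. The split value and the pivot mentioned here together define the splitting hyperplane described above.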
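The split selection described above can be sketched as two small helpers. The function names `max_variance_dim` and `median_index` are assumptions made for this example. The variance is computed as E[x^2] - (E[x])^2, and the median is found with `std::nth_element` instead of the full sort mentioned in the text; this is a substitution that places the same median element at the middle index without ordering the rest.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

using Point = std::vector<float>;  // hypothetical alias for one k-dimensional point

// Returns the index of the dimension with the largest (population) variance.
std::size_t max_variance_dim(const std::vector<Point>& points) {
    assert(!points.empty());
    const std::size_t dim = points[0].size();
    const double n = static_cast<double>(points.size());
    std::size_t best_dim = 0;
    double best_var = -1.0;
    for (std::size_t d = 0; d < dim; ++d) {
        double sum = 0.0, sum_sq = 0.0;
        for (const Point& p : points) {
            sum += p[d];
            sum_sq += static_cast<double>(p[d]) * p[d];
        }
        const double mean = sum / n;
        const double var = sum_sq / n - mean * mean;  // E[x^2] - (E[x])^2
        if (var > best_var) {
            best_var = var;
            best_dim = d;
        }
    }
    return best_dim;
}

// Reorders `points` so that the median along `dim` sits at index size()/2,
// with smaller-or-equal values before it, and returns that index.
std::size_t median_index(std::vector<Point>& points, std::size_t dim) {
    assert(!points.empty());
    const std::size_t mid = points.size() / 2;
    std::nth_element(points.begin(), points.begin() + mid, points.end(),
                     [dim](const Point& a, const Point& b) { return a[dim] < b[dim]; });
    return mid;
}
```

For the five 2-D points used in the test program of Section III, the x coordinates have variance 6.16 and the y coordinates 3.44, so the first split is on x and the pivot is the point with the median x value, (7, 2).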
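2. Nearest-neighbor search
Nearest-neighbor search means finding the point in the kd-tree that is closest to an input (target) point. The procedure is as follows:
1) Starting from the root node, descend recursively, at each node moving to the side of the splitting hyperplane on which the target lies.
2) When a leaf node is reached, take it as the "current best" by default and compute the distance between it and the target point, using the Euclidean distance: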
$$d(x, y)=\sqrt{(x_1 - y_1)^2 + (x_2 - y_2)^2 + \cdots + (x_n - y_n)^2} = \sqrt{\sum_{i=1}^{n} (x_i - y_i)^2}$$
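3) Backtrack along the path, and for each node passed on the way do the following: a. If the current node is closer to the input point than the current best, make it the current best. b. Check whether the subtree on the other side of the node could contain a closer point. It can only do so if the distance from the target to the node's splitting hyperplane is smaller than the current best distance; if that is the case, search downward from that child node, otherwise skip the subtree.
4) Once the root node has been processed, the nearest-neighbor search is complete.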
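The four steps above can be sketched recursively as shown below. This is a minimal illustration, not the implementation from Section III: the names `SketchNode`, `euclidean`, `nearest`, and `find_nearest` are assumptions, and the convention that targets with a coordinate less than or equal to the node's go left is chosen to match the partition rule used later in the article.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

using Point = std::vector<float>;  // hypothetical alias for one k-dimensional point

struct SketchNode {                // hypothetical minimal node type
    Point point;
    std::size_t split = 0;
    const SketchNode* left = nullptr;
    const SketchNode* right = nullptr;
};

// Euclidean distance between two points of equal dimension.
double euclidean(const Point& a, const Point& b) {
    assert(a.size() == b.size());
    double sum = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i) {
        const double diff = static_cast<double>(a[i]) - b[i];
        sum += diff * diff;
    }
    return std::sqrt(sum);
}

void nearest(const SketchNode* node, const Point& target,
             const SketchNode*& best, double& best_dist) {
    if (node == nullptr) return;
    // Signed distance from the target to this node's splitting plane.
    const double diff =
        static_cast<double>(target[node->split]) - node->point[node->split];
    const SketchNode* near_side = diff <= 0.0 ? node->left : node->right;
    const SketchNode* far_side = diff <= 0.0 ? node->right : node->left;

    nearest(near_side, target, best, best_dist);      // steps 1-2: descend first
    const double d = euclidean(node->point, target);  // step 3a: test this node
    if (d < best_dist) {
        best_dist = d;
        best = node;
    }
    if (std::fabs(diff) < best_dist) {                // step 3b: prune or search
        nearest(far_side, target, best, best_dist);
    }
}

// Returns the nearest node (nullptr for an empty tree) and its distance.
const SketchNode* find_nearest(const SketchNode* root, const Point& target,
                               double& best_dist) {
    const SketchNode* best = nullptr;
    best_dist = std::numeric_limits<double>::infinity();
    nearest(root, target, best, best_dist);
    return best;
}
```

The pruning test in step 3b is what makes the search faster than a linear scan: a whole subtree is skipped whenever the splitting plane is farther from the target than the best point found so far.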
III. kd-tree code implementation
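// kd_tree.h
#pragma once
#include <iostream>
#include <algorithm>
#include <cmath>
#include <stack>
#include <vector>
namespace alg {
using namespace std;
class KdTree {
public:
    using DataV = std::vector<float>;   // one k-dimensional point
    using DataVS = std::vector<DataV>;  // a set of points
    struct Node {
        DataV point_;            // data point stored at this node
        int split_ = 0;          // index of the splitting axis
        Node *left_ = nullptr;   // points with coordinate <= point_ on split_
        Node *right_ = nullptr;  // points with coordinate >  point_ on split_
    };
    enum COMPARE_TYPE {
        SPECIFY_DIM,
        ALL
    };
    // Orders two points by their coordinate on the axis stored in compare_dim_.
    static bool compare(DataV& a, DataV& b);
    // Returns true if both points have the same size and identical coordinates.
    bool equal(DataV& a, DataV& b);
    // Sorts points along the chosen axis; outputs that axis and the median point.
    void choose_split(DataVS& points, int &split, DataV& split_choice);
    // Builds the subtree for points and returns its root (nullptr if empty).
    Node* build_tree(DataVS& points, Node *node);
    // Euclidean distance; returns -1 if the dimensions differ.
    float distance(DataV& a, DataV& b);
    // Finds the point in the tree rooted at node that is closest to target.
    void search_nearest(Node *node, DataV& target, DataV& nearst_point, double& distance);
private:
    static int compare_dim_;  // axis used by compare(); set by choose_split()
};
}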
//kd_tree.cpp
#include "kd_tree.h"
namespace alg{
int KdTree::compare_dim_ = 0;
bool KdTree::compare(alg::KdTree::DataV &a, alg::KdTree::DataV &b) {
return a.at(compare_dim_) < b.at(compare_dim_);
}
bool KdTree::equal(alg::KdTree::DataV &a, alg::KdTree::DataV &b) {
if (a.size() != b.size()){
return false;
}
for(int i = 0; i < a.size(); ++i){
if (a.at(i) != b.at(i))
return false;
}
return true;
}
void KdTree::choose_split(alg::KdTree::DataVS& points, int &split, alg::KdTree::DataV &split_choice) {
    // Precondition: points is non-empty and all points have the same dimension.
    int dim = points[0].size();
    int size = points.size();
    std::vector<double> sum_sq(dim, 0.0), sum(dim, 0.0);
    for (const auto& point : points) {
        for (int d = 0; d < dim; ++d) {
            sum_sq[d] += static_cast<double>(point.at(d)) * point.at(d);
            sum[d] += point.at(d);
        }
    }
    // Pick the axis with the largest variance: Var(x) = E[x^2] - (E[x])^2.
    compare_dim_ = 0;
    double max_variance = -1.0;
    for (int d = 0; d < dim; ++d) {
        double mean = sum[d] / size;
        double variance = sum_sq[d] / size - mean * mean;
        if (variance > max_variance) {
            max_variance = variance;
            compare_dim_ = d;
        }
    }
    // Sort along that axis and take the median point as the pivot.
    std::sort(points.begin(), points.end(), compare);
    split = compare_dim_;
    split_choice = points[size / 2];
}
KdTree::Node* KdTree::build_tree(alg::KdTree::DataVS& points, alg::KdTree::Node *node) {
    // An empty point set produces an empty subtree.
    if (points.empty())
        return nullptr;
    int split;
    DataV pivot;
    // Sorts points along the chosen axis and returns the median point as the pivot.
    this->choose_split(points, split, pivot);
    alg::KdTree::DataVS left_points, right_points;
    for (auto& point : points) {
        // The pivot, and any exact duplicate of it, is stored in this node only.
        if (equal(point, pivot))
            continue;
        if (point[split] <= pivot[split]) {
            left_points.push_back(point);
        } else {
            right_points.push_back(point);
        }
    }
    // The incoming node argument is ignored; a fresh node is always allocated.
    node = new Node;
    node->point_ = std::move(pivot);
    node->split_ = split;
    node->left_ = this->build_tree(left_points, node->left_);
    node->right_ = this->build_tree(right_points, node->right_);
    return node;
}
float KdTree::distance(alg::KdTree::DataV &a, alg::KdTree::DataV &b) {
if (a.size() != b.size())
return -1;
float tmp = 0.0f;
for (int i = 0; i < a.size(); i++){
tmp += pow(a.at(i) - b.at(i), 2);
}
return sqrt(tmp);
}
void KdTree::search_nearest(alg::KdTree::Node *node, alg::KdTree::DataV &target, alg::KdTree::DataV &nearst_point,
                            double &distance) {
    // An empty tree has no nearest point; report that with a negative distance.
    if (node == nullptr) {
        nearst_point.clear();
        distance = -1.0;
        return;
    }
    std::stack<Node*> search_points;
    // Walks down from start toward the target, pushing every node on the path.
    auto descend = [&](Node *start) {
        for (Node *p = start; p != nullptr; ) {
            search_points.push(p);
            p = (target.at(p->split_) <= p->point_.at(p->split_)) ? p->left_ : p->right_;
        }
    };
    descend(node);
    Node *best = nullptr;
    double best_dist = 0.0;
    // Backtrack: when a node is popped, its near subtree has already been searched.
    while (!search_points.empty()) {
        Node *current = search_points.top();
        search_points.pop();
        // a. The node itself may be closer than the current best.
        double d = this->distance(current->point_, target);
        if (best == nullptr || d < best_dist) {
            best = current;
            best_dist = d;
        }
        // b. The far subtree can only hold a closer point if the splitting
        //    plane is nearer to the target than the current best distance.
        int s = current->split_;
        double plane_dist = fabs(target.at(s) - current->point_.at(s));
        if (plane_dist < best_dist) {
            Node *far_side = (target.at(s) <= current->point_.at(s)) ? current->right_ : current->left_;
            descend(far_side);
        }
    }
    // Copy (do not move) the point so that the tree stays intact.
    nearst_point = best->point_;
    distance = best_dist;
}
} // namespace alg
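// kd_tree_test.cpp
#include "kd_tree.h"
using namespace alg;
// Recursively frees a tree allocated by KdTree::build_tree.
static void free_tree(KdTree::Node* node) {
    if (node == nullptr) return;
    free_tree(node->left_);
    free_tree(node->right_);
    delete node;
}
int main(int argc, char** argv){
    KdTree::DataVS input = {{2.f, 3.f}, {5.f, 4.f}, {9.f, 6.f}, {8.f, 7.f}, {7.f, 2.f}};
    KdTree::DataV target = {2.1f, 3.1f};
    KdTree tree;
    KdTree::Node* root = nullptr;
    root = tree.build_tree(input, root);
    KdTree::DataV nearest_point;
    double dist;
    tree.search_nearest(root, target, nearest_point, dist);
    // The true nearest neighbor of (2.1, 3.1) in this data set is (2, 3),
    // at a distance of about 0.141.
    std::cout << "nearest point: (";
    for (std::size_t i = 0; i < nearest_point.size(); ++i) {
        std::cout << (i ? ", " : "") << nearest_point[i];
    }
    std::cout << "), distance: " << dist << std::endl;
    free_tree(root);
    return 0;
}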