KD-Tree构建及邻近点搜索 C++实现
1、KD-Tree构建原理
KD-Tree是一种树形数据结构,常用于点云处理。KD-Tree会在其非叶结点上对数据进行划分,在叶子结点上存储数据。
结点数据结构如下:
// Node of the KD-tree.
//
// Inner nodes describe a split (split_dim / split_center); point indices are
// stored only in leaves. A node owns its split_center array. Child nodes are
// NOT released by the destructor; the tree is torn down node by node by
// KDTreeSearch::DeleteKDTree.
struct TreeNode
{
    // Dimensionality of the points (e.g. 2 for 2-D points, 3 for 3-D points).
    int dim = 1;
    // Dimension along which this node splits its data; -1 until a split is computed.
    int split_dim = -1;
    // Number of point indices stored in the leaves below this node.
    int data_size = 0;
    // True when this node is the left child of its parent, false for a right child.
    bool is_left = true;
    // Children: [0] is the left child, [1] the right child; nullptr when absent.
    TreeNode* chirld[2];
    // Parent node; nullptr when no parent has been assigned.
    TreeNode* father;
    // Per-dimension mean of the node's data (array of length dim).
    double* split_center;
    // Indices of the points held by this node; inner nodes release it after splitting.
    std::vector<int> data;

    // Creates a node for in_dim-dimensional points with a zero-filled split_center.
    TreeNode(int in_dim)
        : dim(in_dim),
          chirld{nullptr, nullptr},
          father(nullptr),
          split_center(new double[in_dim]())
    {
    }

    // Releases the split_center array. Deleting nullptr is a no-op, so code
    // that already freed the array and reset the pointer remains valid.
    ~TreeNode()
    {
        delete[] split_center;
    }

    // A shallow copy would share split_center and free it twice, so copying is disabled.
    TreeNode(const TreeNode&) = delete;
    TreeNode& operator=(const TreeNode&) = delete;
};
通常情况下,KD-Tree的构建过程有三步:
(1)计算最佳的划分维度以及当前待划分数据的平均值,通常以当前结点待划分数据的最大方差维度作为划分轴。
(2)根据划分轴及数据平均值对数据进行划分。
(3)递归左孩子,右孩子,直到所有数据划分完成。
当数据量特别大时,递归建树的调用层数会很深,可能耗尽函数调用栈空间,此后再调用函数时程序会因栈溢出而抛出异常。
// Computes the per-dimension mean of the selected points and picks the
// dimension with the largest spread as the split dimension.
//
// data      : all points; every selected point is read at positions [0, dim).
// index     : indices into data of the points to analyse.
// dim       : number of dimensions to consider.
// center    : output array of length dim, receives the per-dimension mean
//             (all zeros when index is empty).
// split_dim : output, dimension with the largest sum of squared deviations;
//             the lowest dimension wins ties. Set to -1 when dim <= 0 or
//             center is null, in which case nothing else is written.
void KDTreeSearch::GetDataFeatrue(
    const vector<vector<double>>& data,
    const vector<int>& index,
    int dim,
    double* center,
    int& split_dim)
{
    split_dim = -1;
    if (dim <= 0 || center == nullptr) return;

    split_dim = 0;
    std::fill(center, center + dim, 0.);
    // Nothing to average: avoid dividing by zero and leave the center at zero.
    if (index.empty()) return;

    // Accumulate the mean directly into the output buffer.
    for (size_t i = 0; i < index.size(); i++)
    {
        const vector<double>& point = data[index[i]];
        for (int j = 0; j < dim; j++)
        {
            center[j] += point[j];
        }
    }
    const double count = static_cast<double>(index.size());
    for (int j = 0; j < dim; j++)
    {
        center[j] /= count;
    }

    // Sum of squared deviations per dimension. Dividing by the count and
    // taking the square root are monotonic, so they are not needed to find
    // the dimension with the largest standard deviation.
    vector<double> spread(dim, 0.);
    for (size_t i = 0; i < index.size(); i++)
    {
        const vector<double>& point = data[index[i]];
        for (int j = 0; j < dim; j++)
        {
            const double delta = point[j] - center[j];
            spread[j] += delta * delta;
        }
    }

    // max_element returns the first maximum, i.e. the lowest dimension on ties.
    split_dim = static_cast<int>(std::max_element(spread.begin(), spread.end()) - spread.begin());
}
void KDTreeSearch::CreateKDTree(const vector<vector<double>>& data, TreeNode* head)
{
//当前结点只有一个数据时,退出
if (head->data_size <= 1) return;
//1.获取分割维度以及平均值
GetDataFeatrue(data, head->data, head->dim, head->split_center, head->split_dim);
//2.构造左右子树
for (int i = 0; i < 2; i++)
{
head->chirld[i] = new TreeNode(head->dim);
head->chirld[i]->is_left = 1 - i; //表示当前结点是否为父结点的左孩子
head->chirld[i]->father = head; //将孩子结点指向父结点
}
//3.根据分割维度对数据进行分割
for (int i = 0; i < head->data_size; i++)
{
int k = head->data[i];
const auto& p = data[k];
if (p[head->split_dim] < head->split_center[head->split_dim])
{
//小于分割值,加入左孩子
head->chirld[0]->data.push_back(k);
head->chirld[0]->data_size++;
}
else
{
//大于分割值,加入右孩子
head->chirld[1]->data.push_back(k);
head->chirld[1]->data_size++;
}
}
//4.清空当前结点存储的数据
vector<int>().swap(head->data);
//5.递归
for (int i = 0; i < 2; i++)
{
CreateKDTree(data, head->chirld[i]);
}
}
2、KD-Tree的遍历与删除
KD-Tree的遍历与二叉树的遍历相同,可以使用先序、中序、后序等遍历方式,但KD-Tree只有叶子结点存储真正的数据,所以非叶结点直接跳过即可。下面为KD-Tree的后序遍历方式。
// Appends to index the point indices stored in the leaves below head.
// The subtree is walked in post-order (left child, right child, then the
// node itself); only nodes without a left child contribute their data.
// A null head is ignored.
void KDTreeSearch::ErgodicKDTree(TreeNode* head, vector<int>& index)
{
    if (head != nullptr)
    {
        // Visit both subtrees first.
        for (int side = 0; side < 2; ++side)
        {
            ErgodicKDTree(head->chirld[side], index);
        }

        // A node without a left child is treated as a leaf.
        const bool no_left_child = (head->chirld[0] == nullptr);
        if (no_left_child)
        {
            const vector<int>& stored = head->data;
            index.insert(index.end(), stored.begin(), stored.end());
        }
    }
}
如果在函数中通过 new 等非智能指针的方式申请内存,而在退出函数时没有释放,将会导致内存泄漏。一般来说,如果程序执行该函数后立即结束,后续不做任何操作,内存泄漏基本没有影响;但当程序频繁调用该函数,或者在该函数之后还要进行其他操作时,前者会导致内存消耗越来越大,后者可能污染内存。所以KD-Tree的删除不可避免,删除时同样采用后序遍历的方式释放内存。
// Frees the subtree rooted at head in post-order: children first, then the
// node's split_center array, then the node itself. The parent's pointer to
// the node is cleared before the node is destroyed. A null head is ignored.
void KDTreeSearch::DeleteKDTree(TreeNode* head)
{
    if (!head)
    {
        return;
    }

    // Release both subtrees before touching this node.
    for (int side = 0; side < 2; ++side)
    {
        DeleteKDTree(head->chirld[side]);
    }

    // Detach from the parent so it does not keep a dangling child pointer.
    TreeNode* const parent = head->father;
    if (parent != nullptr)
    {
        const int slot = head->is_left ? 0 : 1;
        parent->chirld[slot] = nullptr;
    }

    // Free the array, reset the pointer, then free the node.
    delete[] head->split_center;
    head->split_center = nullptr;
    delete head;
}
3、基于KD-Tree的近邻点搜索
通常情况下,近邻点搜索有两种方式——K近邻搜索与半径搜索。近邻点搜索一般采用先找最近点,再回溯搜索父结点其他子树中的结点。
K近邻搜索,即搜索当前点最近的K个点。这里使用一个偷懒的方式实现,主要步骤如下:
(1)当K大于当前结点的点数并且小于当前结点父结点的点数时,将父结点的所有叶子结点的点传出。
(2)计算当前点到搜索到的点的欧式距离。
(3)根据欧式距离进行升序排序,取前K个点。
// Collects candidate point indices for a K-nearest-neighbour query.
//
// head  : subtree to search; nullptr yields no candidates.
// p     : query point.
// K     : number of neighbours wanted.
// index : output; candidate indices are appended to it.
//
// The tree is descended towards p for as long as the current node still
// holds at least K points. The candidates are all points of the smallest
// subtree on that path known to hold at least K points (the whole tree when
// it holds fewer than K). This is a heuristic: closer points lying in other
// subtrees are not considered, so the result is approximate.
void KDTreeSearch::KnnSearchKDTree(
    TreeNode* head,
    const vector<double>& p,
    int K,
    vector<int>& index)
{
    TreeNode* node = head;
    while (node != nullptr)
    {
        // Too few points here: the parent was the last node with at least K
        // points (the root has no parent and is used as-is).
        if (node->data_size < K)
        {
            TreeNode* source = (node->father != nullptr) ? node->father : node;
            ErgodicKDTree(source, index);
            return;
        }

        // A leaf has no split to follow (its split_dim may still be -1), and a
        // node whose split dimension is not addressable in p cannot be
        // descended; in both cases report the node's own points.
        const bool is_leaf = (node->chirld[0] == nullptr || node->chirld[1] == nullptr);
        const int axis = node->split_dim;
        if (is_leaf || axis < 0 || static_cast<size_t>(axis) >= p.size())
        {
            ErgodicKDTree(node, index);
            return;
        }

        // Follow the side of the split that contains the query point.
        node = (p[axis] < node->split_center[axis]) ? node->chirld[0] : node->chirld[1];
    }
}
// K-nearest-neighbour query.
//
// p     : query point; must provide at least m_dim coordinates.
// K     : number of neighbours wanted.
// index : output, indices of the neighbours found, closest first.
// dist  : output, Euclidean distance of each returned neighbour (parallel to index).
//
// Both outputs are overwritten. They are empty when K <= 0, when no tree has
// been built, or when p has fewer than m_dim coordinates. Fewer than K
// results are returned when the candidate set is smaller than K. Candidates
// come from KnnSearchKDTree, so the result is only as exact as that
// candidate selection.
void KDTreeSearch::KnnSearch(
    const vector<double>& p,
    int K,
    vector<int>& index,
    vector<double>& dist)
{
    // Start from a clean state so stale caller data never becomes a candidate.
    index.clear();
    dist.clear();
    if (K <= 0 || m_head == nullptr || static_cast<int>(p.size()) < m_dim) return;

    vector<int> candidates;
    KnnSearchKDTree(m_head, p, K, candidates);

    // Pair each candidate with its distance; distance first so the default
    // pair ordering sorts by distance and breaks ties by index.
    vector<pair<double, int>> ranked;
    ranked.reserve(candidates.size());
    for (size_t i = 0; i < candidates.size(); i++)
    {
        const int t = candidates[i];
        if (t < 0 || static_cast<size_t>(t) >= m_tree_data.size()) continue;
        ranked.push_back(make_pair(Distance(p.data(), m_tree_data[t].data(), m_dim), t));
    }

    // Only the K closest entries need to be ordered.
    size_t keep = ranked.size();
    if (static_cast<size_t>(K) < keep) keep = static_cast<size_t>(K);
    partial_sort(ranked.begin(), ranked.begin() + keep, ranked.end());

    index.resize(keep);
    dist.resize(keep);
    for (size_t i = 0; i < keep; i++)
    {
        index[i] = ranked[i].second;
        dist[i] = ranked[i].first;
    }
}
半径搜索,即从KD-Tree中搜索与当前点的欧式距离小于给定半径的所有点。主要步骤如下:
(1)搜索离当前点最近的叶子结点。
(2)从当前结点回溯,遍历父节点其他孩子节点的所有点,计算到当前点的欧式距离,并根据该距离插入到链表中。
(3)当父节点其他节点的点到当前点的欧式距离都大于搜索半径时,退出。
// Singly linked list node used by the radius search to hand back the points
// found within the search radius.
struct SearchList
{
    // Index of the point; -1 while unset.
    int ind{-1};
    // Distance from the point to the query point.
    double dist{0.0};
    // Next list node, or nullptr at the end of the list.
    SearchList* next{nullptr};
};
void KDTreeSearch::RadiusSearch(
const vector<double>& p,
double r,
vector<int>& index,
vector<double>& dist)
{
//1.根据半径搜索邻近点,将搜索结果放入链表中
SearchList* ls = new SearchList;
RadiusSearchKDTree(m_tree_data, m_head, p, r, ls);
//2.输出搜索结果
index.clear();
dist.clear();
SearchList* node = ls->next;
while (node != nullptr)
{
index.push_back(node->ind);
dist.push_back(node->dist);
node = node->next;
}
//3.删除链表
node = ls->next;
ls->next = nullptr;
delete ls; ls = nullptr;
while (node != nullptr)
{
ls = node;
node = ls->next;
ls->next = nullptr;
delete ls; ls = nullptr;
}
}
void KDTreeSearch::RadiusSearchKDTree(
const vector<vector<double>>& data,
TreeNode* head,
const vector<double>& p,
double r,
SearchList* ls)
{
if (head == nullptr) return;
//1.找到最近点
TreeNode* father, * son;
son = head;
while (son != nullptr && son->chirld[0] != nullptr)
{
father = son;
// 根据当前结点的分割维度以及分割值,决定在哪个子树中搜索邻近点
if (p[son->split_dim] < son->split_center[son->split_dim])
{
son = son->chirld[0];
}
else
{
son = son->chirld[1];
}
}
//最近的点索引
int nind = son->data[0];
double d = Distance(data[nind].data(), p.data(), son->dim);
ls->next = nullptr;
//当前点与目标距离大于搜索半径,直接退出
if (d > r) return;
//2.回溯搜索邻近点
//将最近点加入到链表中
ls->next = new SearchList;
SearchList* node = ls->next;
node->next = nullptr;
node->ind = nind;
node->dist = d;
father = son->father;
//所有点到目标的距离都大于搜索半径
bool all_lager = true;
while (true)
{
//叶子结点的数目
vector<int> leaf;
//根据当前孩子结点在父结点中是否为左结点,
//从而遍历父节点另外的孩子结点
if (son->is_left)
{
ErgodicKDTree(father->chirld[1], leaf);
}
else
{
ErgodicKDTree(father->chirld[0], leaf);
}
//向链表中插入小于搜索距离的点
for (size_t i = 0; i < leaf.size(); i++)
{
int k = leaf[i];
d = Distance(data[k].data(), p.data(), son->dim);
if (d < r)
{
all_lager = false;
SearchList* q = new SearchList;
q->next = nullptr;
q->dist = d;
q->ind = k;
SearchList* qf = ls;
while (true)
{
node = qf->next;
if (node == nullptr) //当前结点为空,直接在链表末尾加入
{
qf->next = q;
break;
}
if (node->dist > q->dist) //当前结点到目标结点的距离大于将要加入的结点,将待加入的结点插入到当前结点前
{
q->next = node;
qf->next = q;
break;
}
else
{
qf = qf->next;
}
}
}
}
//返回上一层
son = father;
father = son->father;
if (all_lager || father == nullptr) break;
}
}
4、其他
(1)KD-Tree内使用索引代替真正的数据与非叶节点不存储数据可以减少内存消耗。
5、完整代码
(1)KDTree.h
#pragma once
#include <algorithm>
#include <cmath>
#include <cstring>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;
// Node of the KD-tree.
//
// Inner nodes describe a split (split_dim / split_center); point indices are
// stored only in leaves. A node owns its split_center array. Child nodes are
// NOT released by the destructor; the tree is torn down node by node by
// KDTreeSearch::DeleteKDTree.
struct TreeNode
{
    // Dimensionality of the points (e.g. 2 for 2-D points, 3 for 3-D points).
    int dim = 1;
    // Dimension along which this node splits its data; -1 until a split is computed.
    int split_dim = -1;
    // Number of point indices stored in the leaves below this node.
    int data_size = 0;
    // True when this node is the left child of its parent, false for a right child.
    bool is_left = true;
    // Children: [0] is the left child, [1] the right child; nullptr when absent.
    TreeNode* chirld[2];
    // Parent node; nullptr when no parent has been assigned.
    TreeNode* father;
    // Per-dimension mean of the node's data (array of length dim).
    double* split_center;
    // Indices of the points held by this node; inner nodes release it after splitting.
    std::vector<int> data;

    // Creates a node for in_dim-dimensional points with a zero-filled split_center.
    TreeNode(int in_dim)
        : dim(in_dim),
          chirld{nullptr, nullptr},
          father(nullptr),
          split_center(new double[in_dim]())
    {
    }

    // Releases the split_center array. Deleting nullptr is a no-op, so code
    // that already freed the array and reset the pointer remains valid.
    ~TreeNode()
    {
        delete[] split_center;
    }

    // A shallow copy would share split_center and free it twice, so copying is disabled.
    TreeNode(const TreeNode&) = delete;
    TreeNode& operator=(const TreeNode&) = delete;
};
// Singly linked list node used by the radius search to hand back the points
// found within the search radius.
struct SearchList
{
    // Index of the point; -1 while unset.
    int ind{-1};
    // Distance from the point to the query point.
    double dist{0.0};
    // Next list node, or nullptr at the end of the list.
    SearchList* next{nullptr};
};
// Nearest-neighbour search on a KD-tree built over a set of points.
//
// The object owns the tree (m_head) and frees it in the destructor. Copying
// is disabled because a copy would share the same nodes and free them twice.
class KDTreeSearch
{
public:
    KDTreeSearch();
    ~KDTreeSearch();
    KDTreeSearch(const KDTreeSearch&) = delete;
    KDTreeSearch& operator=(const KDTreeSearch&) = delete;
public:
    // Sets the dimensionality of the points.
    void SetDataDim(int dim);
    // Builds the KD-tree over tree_data.
    void CreateKDTree(const vector<vector<double>>& tree_data);
    // K-nearest-neighbour search around p; fills index and dist.
    void KnnSearch(const vector<double>& p, int K, vector<int>& index, vector<double>& dist);
    // Radius search around p with radius r; fills index and dist.
    void RadiusSearch(const vector<double>& p, double r, vector<int>& index, vector<double>& dist);
private:
    // Dimensionality of the points.
    int m_dim;
    // Root node of the KD-tree; nullptr while no tree exists.
    TreeNode* m_head;
    // Copy of the input points; the tree stores indices into this vector.
    vector<vector<double>> m_tree_data;
private:
    // Euclidean distance between two points of len coordinates.
    static double Distance(const double* p1, const double* p2, int len);
    // Computes the mean of the selected points and the best split dimension.
    static void GetDataFeatrue(
        const vector<vector<double>>& data,
        const vector<int>& index,
        int dim,
        double* center,
        int& split_dim);
    // Recursively builds the KD-tree below head.
    static void CreateKDTree(const vector<vector<double>>& data, TreeNode* head);
    // Collects the point indices stored below a node.
    static void ErgodicKDTree(TreeNode* head, vector<int>& index);
    // Frees a (sub)tree; called from the destructor to avoid leaking the tree.
    static void DeleteKDTree(TreeNode* head);
    // Collects candidate indices for a K-nearest-neighbour query.
    static void KnnSearchKDTree(
        TreeNode* head,
        const vector<double>& p,
        int K,
        vector<int>& index);
    // Collects the points around p into the linked list ls.
    static void RadiusSearchKDTree(
        const vector<vector<double>>& data,
        TreeNode* head,
        const vector<double>& p,
        double r,
        SearchList* ls);
};
(2)KDTree.cpp
#include "KDTree.h"
// Default state: points are three-dimensional and no tree has been built yet.
KDTreeSearch::KDTreeSearch()
    : m_dim(3)
    , m_head(nullptr)
{
}
// Releases every node of the tree owned by this object.
KDTreeSearch::~KDTreeSearch()
{
    TreeNode* const root = m_head;
    DeleteKDTree(root);
}
void KDTreeSearch::SetDataDim(int dim)
{
m_dim = dim;
}
void KDTreeSearch::CreateKDTree(const vector<vector<double>>& data)
{
//1.保存输入数据
m_tree_data = data;
//2.建树
m_head = new TreeNode(m_dim);
//3.构建当前数据的索引
m_head->data_size = data.size();
for (size_t i = 0; i < data.size(); i++)
{
m_head->data.push_back(i);
}
//4.建树
CreateKDTree(data, m_head);
}
// K-nearest-neighbour query.
//
// p     : query point; must provide at least m_dim coordinates.
// K     : number of neighbours wanted.
// index : output, indices of the neighbours found, closest first.
// dist  : output, Euclidean distance of each returned neighbour (parallel to index).
//
// Both outputs are overwritten. They are empty when K <= 0, when no tree has
// been built, or when p has fewer than m_dim coordinates. Fewer than K
// results are returned when the candidate set is smaller than K. Candidates
// come from KnnSearchKDTree, so the result is only as exact as that
// candidate selection.
void KDTreeSearch::KnnSearch(
    const vector<double>& p,
    int K,
    vector<int>& index,
    vector<double>& dist)
{
    // Start from a clean state so stale caller data never becomes a candidate.
    index.clear();
    dist.clear();
    if (K <= 0 || m_head == nullptr || static_cast<int>(p.size()) < m_dim) return;

    vector<int> candidates;
    KnnSearchKDTree(m_head, p, K, candidates);

    // Pair each candidate with its distance; distance first so the default
    // pair ordering sorts by distance and breaks ties by index.
    vector<pair<double, int>> ranked;
    ranked.reserve(candidates.size());
    for (size_t i = 0; i < candidates.size(); i++)
    {
        const int t = candidates[i];
        if (t < 0 || static_cast<size_t>(t) >= m_tree_data.size()) continue;
        ranked.push_back(make_pair(Distance(p.data(), m_tree_data[t].data(), m_dim), t));
    }

    // Only the K closest entries need to be ordered.
    size_t keep = ranked.size();
    if (static_cast<size_t>(K) < keep) keep = static_cast<size_t>(K);
    partial_sort(ranked.begin(), ranked.begin() + keep, ranked.end());

    index.resize(keep);
    dist.resize(keep);
    for (size_t i = 0; i < keep; i++)
    {
        index[i] = ranked[i].second;
        dist[i] = ranked[i].first;
    }
}
void KDTreeSearch::RadiusSearch(
const vector<double>& p,
double r,
vector<int>& index,
vector<double>& dist)
{
//1.根据半径搜索邻近点,将搜索结果放入链表中
SearchList* ls = new SearchList;
RadiusSearchKDTree(m_tree_data, m_head, p, r, ls);
//2.输出搜索结果
index.clear();
dist.clear();
SearchList* node = ls->next;
while (node != nullptr)
{
index.push_back(node->ind);
dist.push_back(node->dist);
node = node->next;
}
//3.删除链表
node = ls->next;
ls->next = nullptr;
delete ls; ls = nullptr;
while (node != nullptr)
{
ls = node;
node = ls->next;
ls->next = nullptr;
delete ls; ls = nullptr;
}
}
// Euclidean distance between two points.
//
// p1, p2 : coordinate arrays, each readable at positions [0, len).
// len    : number of coordinates to compare; the result is 0 when len <= 0.
double KDTreeSearch::Distance(const double* p1, const double* p2, int len)
{
    double sum_sq = 0.;
    for (int i = 0; i < len; i++)
    {
        // Plain multiplication instead of the general-purpose pow(x, 2).
        const double delta = p1[i] - p2[i];
        sum_sq += delta * delta;
    }
    return sqrt(sum_sq);
}
// Computes the per-dimension mean of the selected points and picks the
// dimension with the largest spread as the split dimension.
//
// data      : all points; every selected point is read at positions [0, dim).
// index     : indices into data of the points to analyse.
// dim       : number of dimensions to consider.
// center    : output array of length dim, receives the per-dimension mean
//             (all zeros when index is empty).
// split_dim : output, dimension with the largest sum of squared deviations;
//             the lowest dimension wins ties. Set to -1 when dim <= 0 or
//             center is null, in which case nothing else is written.
void KDTreeSearch::GetDataFeatrue(
    const vector<vector<double>>& data,
    const vector<int>& index,
    int dim,
    double* center,
    int& split_dim)
{
    split_dim = -1;
    if (dim <= 0 || center == nullptr) return;

    split_dim = 0;
    std::fill(center, center + dim, 0.);
    // Nothing to average: avoid dividing by zero and leave the center at zero.
    if (index.empty()) return;

    // Accumulate the mean directly into the output buffer.
    for (size_t i = 0; i < index.size(); i++)
    {
        const vector<double>& point = data[index[i]];
        for (int j = 0; j < dim; j++)
        {
            center[j] += point[j];
        }
    }
    const double count = static_cast<double>(index.size());
    for (int j = 0; j < dim; j++)
    {
        center[j] /= count;
    }

    // Sum of squared deviations per dimension. Dividing by the count and
    // taking the square root are monotonic, so they are not needed to find
    // the dimension with the largest standard deviation.
    vector<double> spread(dim, 0.);
    for (size_t i = 0; i < index.size(); i++)
    {
        const vector<double>& point = data[index[i]];
        for (int j = 0; j < dim; j++)
        {
            const double delta = point[j] - center[j];
            spread[j] += delta * delta;
        }
    }

    // max_element returns the first maximum, i.e. the lowest dimension on ties.
    split_dim = static_cast<int>(std::max_element(spread.begin(), spread.end()) - spread.begin());
}
void KDTreeSearch::CreateKDTree(const vector<vector<double>>& data, TreeNode* head)
{
//当前结点只有一个数据时,退出
if (head->data_size <= 1) return;
//1.获取分割维度以及平均值
GetDataFeatrue(data, head->data, head->dim, head->split_center, head->split_dim);
//2.构造左右子树
for (int i = 0; i < 2; i++)
{
head->chirld[i] = new TreeNode(head->dim);
head->chirld[i]->is_left = 1 - i; //表示当前结点是否为父结点的左孩子
head->chirld[i]->father = head; //将孩子结点指向父结点
}
//3.根据分割维度对数据进行分割
for (int i = 0; i < head->data_size; i++)
{
int k = head->data[i];
const auto& p = data[k];
if (p[head->split_dim] < head->split_center[head->split_dim])
{
//小于分割值,加入左孩子
head->chirld[0]->data.push_back(k);
head->chirld[0]->data_size++;
}
else
{
//大于分割值,加入右孩子
head->chirld[1]->data.push_back(k);
head->chirld[1]->data_size++;
}
}
//4.清空当前结点存储的数据
vector<int>().swap(head->data);
//5.递归
for (int i = 0; i < 2; i++)
{
CreateKDTree(data, head->chirld[i]);
}
}
// Appends to index the point indices stored in the leaves below head.
// The subtree is walked in post-order (left child, right child, then the
// node itself); only nodes without a left child contribute their data.
// A null head is ignored.
void KDTreeSearch::ErgodicKDTree(TreeNode* head, vector<int>& index)
{
    if (head != nullptr)
    {
        // Visit both subtrees first.
        for (int side = 0; side < 2; ++side)
        {
            ErgodicKDTree(head->chirld[side], index);
        }

        // A node without a left child is treated as a leaf.
        const bool no_left_child = (head->chirld[0] == nullptr);
        if (no_left_child)
        {
            const vector<int>& stored = head->data;
            index.insert(index.end(), stored.begin(), stored.end());
        }
    }
}
// Frees the subtree rooted at head in post-order: children first, then the
// node's split_center array, then the node itself. The parent's pointer to
// the node is cleared before the node is destroyed. A null head is ignored.
void KDTreeSearch::DeleteKDTree(TreeNode* head)
{
    if (!head)
    {
        return;
    }

    // Release both subtrees before touching this node.
    for (int side = 0; side < 2; ++side)
    {
        DeleteKDTree(head->chirld[side]);
    }

    // Detach from the parent so it does not keep a dangling child pointer.
    TreeNode* const parent = head->father;
    if (parent != nullptr)
    {
        const int slot = head->is_left ? 0 : 1;
        parent->chirld[slot] = nullptr;
    }

    // Free the array, reset the pointer, then free the node.
    delete[] head->split_center;
    head->split_center = nullptr;
    delete head;
}
// Collects candidate point indices for a K-nearest-neighbour query.
//
// head  : subtree to search; nullptr yields no candidates.
// p     : query point.
// K     : number of neighbours wanted.
// index : output; candidate indices are appended to it.
//
// The tree is descended towards p for as long as the current node still
// holds at least K points. The candidates are all points of the smallest
// subtree on that path known to hold at least K points (the whole tree when
// it holds fewer than K). This is a heuristic: closer points lying in other
// subtrees are not considered, so the result is approximate.
void KDTreeSearch::KnnSearchKDTree(
    TreeNode* head,
    const vector<double>& p,
    int K,
    vector<int>& index)
{
    TreeNode* node = head;
    while (node != nullptr)
    {
        // Too few points here: the parent was the last node with at least K
        // points (the root has no parent and is used as-is).
        if (node->data_size < K)
        {
            TreeNode* source = (node->father != nullptr) ? node->father : node;
            ErgodicKDTree(source, index);
            return;
        }

        // A leaf has no split to follow (its split_dim may still be -1), and a
        // node whose split dimension is not addressable in p cannot be
        // descended; in both cases report the node's own points.
        const bool is_leaf = (node->chirld[0] == nullptr || node->chirld[1] == nullptr);
        const int axis = node->split_dim;
        if (is_leaf || axis < 0 || static_cast<size_t>(axis) >= p.size())
        {
            ErgodicKDTree(node, index);
            return;
        }

        // Follow the side of the split that contains the query point.
        node = (p[axis] < node->split_center[axis]) ? node->chirld[0] : node->chirld[1];
    }
}
void KDTreeSearch::RadiusSearchKDTree(
const vector<vector<double>>& data,
TreeNode* head,
const vector<double>& p,
double r,
SearchList* ls)
{
if (head == nullptr) return;
//1.找到最近点
TreeNode* father, * son;
son = head;
while (son != nullptr && son->chirld[0] != nullptr)
{
father = son;
// 根据当前结点的分割维度以及分割值,决定在哪个子树中搜索邻近点
if (p[son->split_dim] < son->split_center[son->split_dim])
{
son = son->chirld[0];
}
else
{
son = son->chirld[1];
}
}
//最近的点索引
int nind = son->data[0];
double d = Distance(data[nind].data(), p.data(), son->dim);
ls->next = nullptr;
//当前点与目标距离大于搜索半径,直接退出
if (d > r) return;
//2.回溯搜索邻近点
//将最近点加入到链表中
ls->next = new SearchList;
SearchList* node = ls->next;
node->next = nullptr;
node->ind = nind;
node->dist = d;
father = son->father;
//所有点到目标的距离都大于搜索半径
bool all_lager = true;
while (true)
{
//叶子结点的数目
vector<int> leaf;
//根据当前孩子结点在父结点中是否为左结点,
//从而遍历父节点另外的孩子结点
if (son->is_left)
{
ErgodicKDTree(father->chirld[1], leaf);
}
else
{
ErgodicKDTree(father->chirld[0], leaf);
}
//向链表中插入小于搜索距离的点
for (size_t i = 0; i < leaf.size(); i++)
{
int k = leaf[i];
d = Distance(data[k].data(), p.data(), son->dim);
if (d < r)
{
all_lager = false;
SearchList* q = new SearchList;
q->next = nullptr;
q->dist = d;
q->ind = k;
SearchList* qf = ls;
while (true)
{
node = qf->next;
if (node == nullptr) //当前结点为空,直接在链表末尾加入
{
qf->next = q;
break;
}
if (node->dist > q->dist) //当前结点到目标结点的距离大于将要加入的结点,将待加入的结点插入到当前结点前
{
q->next = node;
qf->next = q;
break;
}
else
{
qf = qf->next;
}
}
}
}
//返回上一层
son = father;
father = son->father;
if (all_lager || father == nullptr) break;
}
}