// Features implemented:
// k-d tree initialization (construction), insertion, deletion, exact-match search, and range search
#include<algorithm>
#include<cmath>
#include<cstdlib>
#include<iostream>
#include<queue>
#include<vector>
// Sentinel magnitude: main() uses +INF / -INF as the default upper / lower
// limit of every dimension before the user narrows a range query.
#define INF 0x3f3f3f
// Makes every name of namespace std visible unqualified; code in this file
// refers to vector, cout, sort, etc. without the std:: prefix.
using namespace std;
/// One node of a k-d tree.
///
/// A node stores a single point (`data`), the axis it partitions on (`split`),
/// links to its children and parent, and the per-dimension bounding box
/// (`mx` / `mn`) of the subtree rooted at this node.
///
/// The node owns the `mx` and `mn` arrays and frees them in its destructor.
/// It does NOT own its children or parent; destroying a node leaves the nodes
/// it links to untouched. Copying is disabled because a member-wise copy
/// would make two nodes free the same arrays.
template<class T>
struct kdTreeNode {
    std::vector<T> data;        // coordinates of the point stored in this node
    int split;                  // axis on which this node partitions its subtree
    kdTreeNode<T> *leftChild;   // left child, or nullptr
    kdTreeNode<T> *rightChild;  // right child, or nullptr
    kdTreeNode<T> *parent;      // parent node, or nullptr
    T *mx;                      // mx[i]: largest i-th coordinate within this subtree
    T *mn;                      // mn[i]: smallest i-th coordinate within this subtree

    /// Creates an empty, unlinked node without a bounding box.
    kdTreeNode()
        : split(0), leftChild(nullptr), rightChild(nullptr), parent(nullptr),
          mx(nullptr), mn(nullptr) {}

    /// Creates an unlinked node holding a copy of point `d`; the bounding box
    /// starts out as the point itself.
    kdTreeNode(const std::vector<T> &d)
        : data(d), split(0), leftChild(nullptr), rightChild(nullptr), parent(nullptr),
          mx(nullptr), mn(nullptr) {
        const typename std::vector<T>::size_type k = data.size();
        // One spare slot beyond the k coordinates is kept, as in the original layout.
        mx = new T[k + 1]();
        mn = new T[k + 1]();
        for (typename std::vector<T>::size_type i = 0; i < k; ++i)
            mx[i] = mn[i] = data[i];
    }

    /// Releases the bounding-box arrays (deleting a null pointer is a no-op).
    ~kdTreeNode() {
        delete[] mx;
        delete[] mn;
    }

    kdTreeNode(const kdTreeNode &) = delete;
    kdTreeNode &operator=(const kdTreeNode &) = delete;

    /// Returns true when the first data.size() coordinates of `d` equal the
    /// stored point. A `d` with fewer coordinates than the stored point never
    /// matches (and is never indexed out of bounds).
    bool cmp(const std::vector<T> &d) const {
        if (d.size() < data.size())
            return false;
        for (typename std::vector<T>::size_type i = 0; i < data.size(); ++i)
            if (data[i] != d[i])
                return false;
        return true;
    }

    /// Prints the point to std::cout as "(c0,c1,...)" with no trailing newline.
    void display() const {
        std::cout << "(";
        for (typename std::vector<T>::size_type i = 0; i < data.size(); ++i) {
            if (i != 0)
                std::cout << ",";
            std::cout << data[i];
        }
        std::cout << ")";
    }
};
template<class T>
class kdTree {
private:
int K; //维度
kdTreeNode<T> *root; //kd树根节点
vector<vector<T> > data; //节点向量
int treeSize; //树节点总数
public:
kdTree(){}
kdTree(int k){
K = k;
root = NULL;
treeSize = 0;
}
kdTreeNode<T>* rt() { return root; }
void constructRoot(vector<vector<T> > &num);
kdTreeNode<T>* construct(vector<vector<T> >&data, int s, int e, int depth);
void insert(vector<T> newData);
void erase(vector<T> &Data);
kdTreeNode<T>* search(vector<T> &Data);
void rangeSearch(T *rmx, T *rmn, vector<vector<T> > &st, kdTreeNode<T> *temp);
void reSet(vector<vector<T> > &st, kdTreeNode<T> *p);
void preO(kdTreeNode<T> *r, vector<vector<T> > &st);
void display();
void display(kdTreeNode<T> *r);
void display2(kdTreeNode<T> *r);
};
/// Builds the tree from scratch out of the points in `num`.
///
/// The dimension K is taken from the first point; all points are expected to
/// have that many coordinates. When `num` is empty, or its points have no
/// coordinates, the tree is simply left empty instead of indexing num[0] or
/// dividing by a zero dimension. Nodes of a previously built tree are not
/// released here.
template<class T>
void kdTree<T>::constructRoot(vector<vector<T> > &num) {
    if (num.empty() || num[0].empty()) {
        data.clear();
        treeSize = 0;
        root = nullptr;
        return;
    }
    K = static_cast<int>(num[0].size());
    data = num;
    treeSize = static_cast<int>(data.size());
    root = construct(data, 0, treeSize - 1, 0);
}
/// Recursively builds a subtree from the points data[s..e] (inclusive).
///
/// The range is sorted in place along axis `depth % K`; the element at the
/// middle index becomes the subtree root, the points before it form the left
/// subtree and the points after it the right subtree. Child parent links and
/// the root's per-dimension bounds are filled in before returning.
/// Returns nullptr for an empty range.
template<class T>
kdTreeNode<T>* kdTree<T>::construct(vector<vector<T> >&data, int s,int e, int depth) {
    const int count = e - s + 1;            // points left in this range
    if (count <= 0)
        return nullptr;

    const int axis = depth % K;             // dimension this level splits on
    auto byAxis = [axis](const vector<T> &lhs, const vector<T> &rhs) {
        return lhs[axis] < rhs[axis];
    };
    sort(data.begin() + s, data.begin() + e + 1, byAxis);

    const int median = s + count / 2;
    kdTreeNode<T> *current = new kdTreeNode<T>(data[median]);
    current->split = axis;
    current->leftChild = construct(data, s, median - 1, depth + 1);
    current->rightChild = construct(data, median + 1, e, depth + 1);

    // Link each existing child back to this node and widen this node's
    // bounds so that they cover the child's subtree.
    kdTreeNode<T> *children[2] = { current->leftChild, current->rightChild };
    for (int c = 0; c < 2; ++c) {
        kdTreeNode<T> *child = children[c];
        if (child == nullptr)
            continue;
        child->parent = current;
        for (int dim = 0; dim < K; ++dim) {
            current->mx[dim] = max(current->mx[dim], child->mx[dim]);
            current->mn[dim] = min(current->mn[dim], child->mn[dim]);
        }
    }
    return current;
}
/// Inserts `newData` as a new leaf unless an equal point is already stored.
///
/// The first point inserted into an empty tree fixes the dimension K.
/// Points that cannot be indexed safely are ignored: an empty point, or a
/// point with fewer than K coordinates when the tree is not empty.
/// After linking the new leaf, the bounds of every ancestor are widened to
/// include the new point.
template<class T>
void kdTree<T>::insert(vector<T> newData) {
    if (newData.empty())
        return;     // would make K zero and break the `% K` axis arithmetic
    if (root != nullptr && static_cast<int>(newData.size()) < K)
        return;     // too few coordinates: indexing by split axis would overrun
    if (search(newData) != nullptr)
        return;     // duplicates are not stored

    kdTreeNode<T> *node = new kdTreeNode<T>(newData);
    data.push_back(newData);
    treeSize++;

    if (root == nullptr) {
        node->split = 0;
        K = static_cast<int>(newData.size());
        root = node;
        return;
    }

    // Walk down to the leaf position: strictly greater goes right, otherwise left.
    kdTreeNode<T> *attachTo = nullptr;
    for (kdTreeNode<T> *cursor = root; cursor != nullptr; ) {
        attachTo = cursor;
        const int axis = cursor->split;
        cursor = (newData[axis] > cursor->data[axis]) ? cursor->rightChild
                                                      : cursor->leftChild;
    }

    node->parent = attachTo;
    node->split = (attachTo->split + 1) % K;
    if (newData[attachTo->split] > attachTo->data[attachTo->split])
        attachTo->rightChild = node;
    else
        attachTo->leftChild = node;

    // Propagate the new point into the bounds of all ancestors.
    for (kdTreeNode<T> *ancestor = attachTo; ancestor != nullptr; ancestor = ancestor->parent) {
        for (int i = 0; i < K; i++) {
            ancestor->mx[i] = max(ancestor->mx[i], newData[i]);
            ancestor->mn[i] = min(ancestor->mn[i], newData[i]);
        }
    }
}
template<class T>
void kdTree<T>::erase(vector<T> &Data) {
kdTreeNode<T> *temp = search(Data);
if (temp == NULL) {
cout << "删除失败!" << endl;
return;
}
data.erase(find(data.begin(), data.end(), Data));
vector<vector<T> > st;
reSet(st, temp->leftChild);
reSet(st, temp->rightChild);
temp = temp->parent;
if (temp == NULL) {
root = construct(data, 0, data.size() - 1, 0);
return;
}
int *mmn, *mmx;
if (temp->rightChild != NULL && temp->rightChild->cmp(Data)) {
temp->rightChild = construct(st, 0, st.size() - 1, temp->split + 1);
mmx = temp->rightChild->mx;
mmn = temp->rightChild->mn;
}
else {
temp->leftChild = construct(st, 0, st.size() - 1, temp->split + 1);
mmx = temp->leftChild->mx;
mmn = temp->leftChild->mn;
}
while (temp != NULL) {
for (int i = 0; i < K; i++) {
temp->mx[i] = max(temp->mx[i], mmx[i]);
temp->mn[i] = min(temp->mn[i], mmn[i]);
}
temp = temp->parent;
}
}
/// Looks up the node whose point equals `Data`.
///
/// Returns the node, or nullptr when the point is not stored (or when `Data`
/// has fewer than K coordinates).
///
/// The search is a depth-first walk pruned by the per-subtree bounds: a
/// subtree is entered only if `Data` lies inside its bounds on every axis.
/// Both children are considered, left first, because points that share the
/// split coordinate of a node may be stored on either side of it; committing
/// to a single child would miss some stored points.
template<class T>
kdTreeNode<T>* kdTree<T>::search(vector<T> &Data) {
    if (root == nullptr || static_cast<int>(Data.size()) < K)
        return nullptr;

    vector<kdTreeNode<T>*> pending(1, root);    // explicit DFS stack
    while (!pending.empty()) {
        kdTreeNode<T> *node = pending.back();
        pending.pop_back();

        bool insideBounds = true;
        for (int i = 0; i < K && insideBounds; i++)
            insideBounds = Data[i] <= node->mx[i] && Data[i] >= node->mn[i];
        if (!insideBounds)
            continue;   // the point cannot be anywhere in this subtree

        if (node->cmp(Data))
            return node;

        // Push right first so that the left child is examined first.
        if (node->rightChild != nullptr)
            pending.push_back(node->rightChild);
        if (node->leftChild != nullptr)
            pending.push_back(node->leftChild);
    }
    return nullptr;
}
/// Appends the points of the subtree rooted at `p` to `st` in pre-order
/// (node, then left subtree, then right subtree). Does nothing for nullptr.
template<class T>
void kdTree<T>::reSet(vector<vector<T> > &st, kdTreeNode<T> *p) {
    vector<kdTreeNode<T>*> todo;    // explicit stack replacing the recursion
    if (p != nullptr)
        todo.push_back(p);
    while (!todo.empty()) {
        kdTreeNode<T> *visiting = todo.back();
        todo.pop_back();
        st.push_back(visiting->data);
        // The right child is pushed first so the left subtree is emitted first.
        if (visiting->rightChild != nullptr)
            todo.push_back(visiting->rightChild);
        if (visiting->leftChild != nullptr)
            todo.push_back(visiting->leftChild);
    }
}
/// Collects into `st` every point of the subtree rooted at `temp` that lies
/// inside the axis-aligned box [rmn[i], rmx[i]] (inclusive) on all K axes.
///
/// rmx / rmn must each provide K entries (upper / lower limits per axis).
/// Subtrees whose bounds do not intersect the box are skipped; subtrees whose
/// bounds lie completely inside the box are copied wholesale in pre-order.
/// Otherwise the node's own point is tested on every axis and both children
/// are searched.
template<class T>
void kdTree<T>::rangeSearch(T *rmx, T *rmn, vector<vector<T> > &st, kdTreeNode<T> *temp) {
    if (temp == nullptr)
        return;

    bool fullyContained = true;
    for (int i = 0; i < K; i++) {
        if (temp->mx[i] < rmn[i] || temp->mn[i] > rmx[i])
            return;                     // no overlap on this axis: nothing to report
        if (!(temp->mx[i] <= rmx[i] && temp->mn[i] >= rmn[i]))
            fullyContained = false;     // subtree sticks out of the box on this axis
    }
    if (fullyContained) {
        preO(temp, st);
        return;
    }

    // Partial overlap: test this node's point on every axis, not just the first two.
    bool pointInRange = true;
    for (int i = 0; i < K && pointInRange; i++)
        pointInRange = temp->data[i] <= rmx[i] && temp->data[i] >= rmn[i];
    if (pointInRange)
        st.push_back(temp->data);

    rangeSearch(rmx, rmn, st, temp->leftChild);
    rangeSearch(rmx, rmn, st, temp->rightChild);
}
/// Pre-order traversal: appends the point of `r`, then the points of its
/// left subtree, then those of its right subtree, to `st`.
template<class T>
void kdTree<T>::preO(kdTreeNode<T> *r, vector<vector<T> > &st) {
    if (r == nullptr)
        return;     // empty subtree contributes nothing
    st.push_back(r->data);
    preO(r->leftChild, st);
    preO(r->rightChild, st);
}
/// Prints the tree to std::cout level by level (breadth-first).
///
/// Nodes of one level are separated by a single space and every level is
/// terminated by a newline. Levels are delimited by the number of nodes
/// actually queued for them, so the layout is also correct for trees that
/// are not complete. An empty tree prints nothing.
template<class T>
void kdTree<T>::display() {
    if (root == nullptr)
        return;

    queue<kdTreeNode<T>* > frontier;
    frontier.push(root);
    while (!frontier.empty()) {
        // Everything queued right now belongs to the same level.
        for (typename queue<kdTreeNode<T>* >::size_type left = frontier.size(); left > 0; --left) {
            kdTreeNode<T> *node = frontier.front();
            frontier.pop();
            node->display();
            if (left == 1)
                cout << endl;
            else
                cout << " ";
            if (node->leftChild != nullptr)
                frontier.push(node->leftChild);
            if (node->rightChild != nullptr)
                frontier.push(node->rightChild);
        }
    }
}
/// Prints the subtree rooted at `r` in pre-order; every point is followed by
/// a single space. Prints nothing for nullptr.
template<class T>
void kdTree<T>::display(kdTreeNode<T> *r) {
    vector<kdTreeNode<T>*> todo;    // explicit stack replacing the recursion
    if (r != nullptr)
        todo.push_back(r);
    while (!todo.empty()) {
        kdTreeNode<T> *visiting = todo.back();
        todo.pop_back();
        visiting->display();
        cout << " ";
        // Right goes onto the stack first so that left is printed first.
        if (visiting->rightChild != nullptr)
            todo.push_back(visiting->rightChild);
        if (visiting->leftChild != nullptr)
            todo.push_back(visiting->leftChild);
    }
}
/// Prints the subtree rooted at `r` in in-order (left subtree, node, right
/// subtree); every point is followed by a single space.
template<class T>
void kdTree<T>::display2(kdTreeNode<T> *r) {
    vector<kdTreeNode<T>*> path;    // ancestors whose own point is still to be printed
    kdTreeNode<T> *cursor = r;
    while (cursor != nullptr || !path.empty()) {
        // Descend as far left as possible, remembering the way back.
        while (cursor != nullptr) {
            path.push_back(cursor);
            cursor = cursor->leftChild;
        }
        cursor = path.back();
        path.pop_back();
        cursor->display();
        cout << " ";
        cursor = cursor->rightChild;
    }
}
/// Interactively reads the initial data set from std::cin.
///
/// Prompts for the dimension `k` and the number of points, then reads that
/// many points of `k` coordinates each and appends them to `num`.
/// Coordinates are read as type T, so the function works for any element
/// type that supports stream extraction, not just int.
template<class T>
void input(int &k, std::vector<std::vector<T> > &num) {
    int a = 0;
    std::cout << "输入维度k: ";
    std::cin >> k;
    std::cout << "输入初始构造的节点数a: ";
    std::cin >> a;
    std::cout << "输入a个k维数据(用空格分开各个数): " << std::endl;
    for (int row = 0; row < a; row++) {
        std::vector<T> point;
        for (int col = 0; col < k; col++) {
            T value = T();
            std::cin >> value;
            point.push_back(value);
        }
        num.push_back(point);
    }
}
/// Reads one k-dimensional point from std::cin and appends its coordinates
/// to `data`.
///
/// Coordinates are read as type T. Exactly `k` values are appended even if an
/// extraction fails, so the caller always gets a point of the expected size.
template<class T>
void input_opData(int &k, std::vector<T> &data) {
    for (int i = 0; i < k; i++) {
        T value = T();
        std::cin >> value;
        data.push_back(value);
    }
}
/// Interactively overrides entries of the limit array `s` (k entries).
///
/// Repeatedly asks for a 1-based dimension number and a value and stores the
/// value in s[dimension - 1]. Entering k + 1 as the dimension ends the input.
/// A dimension outside 1..k is ignored and the dimension prompt is shown
/// again, so `s` is never written out of bounds. Input also ends when
/// std::cin fails (end of input or malformed number) instead of looping
/// forever. Values are read as type T.
template<class T>
void input_ranData(int &k, T *s) {
    for (;;) {
        int dim = 0;
        std::cout << "输入维度: ";
        if (!(std::cin >> dim))
            break;              // EOF or unreadable input
        if (dim == k + 1)
            break;              // explicit stop marker
        if (dim < 1 || dim > k)
            continue;           // not a valid dimension: ask again
        T val = T();
        std::cout << "输入值: ";
        if (!(std::cin >> val))
            break;
        s[dim - 1] = val;
    }
}
int main(void) {
kdTree<int> tree;
int k;
vector<vector<int> > num;
//初始输入并构造
input(k, num);
tree.constructRoot(num);
//操作输入
while (1) {
int op;
cout << "请输入操作码op (1:插入 2:删除 3:精确查找 4:范围查询 5:显示 6:退出): ";
cin >> op;
if (op == 6)
break;
if (op == 1) {
cout << "输入要插入的k维数据点: ";
vector<int> data;
input_opData(k, data);
tree.insert(data);
cout << endl;
}
else if (op == 2) {
cout << "输入要删除的k维数据点: ";
vector<int> data;
input_opData(k, data);
tree.erase(data);
cout << endl;
}
else if (op == 3) {
cout << "输入要精确查找的k维数据点: ";
vector<int> data;
input_opData(k, data);
if (tree.search(data) == NULL)
cout << "没有找到!!!" << endl;
else
cout << "找到" << endl;
cout << endl;
}
else if (op == 4) {
int *mx = new int[k];
int *mn = new int[k];
for (int i = 0; i < k; i++) {
mx[i] = INF;
mn[i] = -INF;
}
cout << "输入查询范围 各维度上限(k+1表示停止输入): " << endl;
input_ranData(k, mx);
cout << "输入查询范围 各维度下限(k+1表示停止输入): " << endl;
input_ranData(k, mn);
vector<vector<int> > res;
tree.rangeSearch(mx, mn, res, tree.rt());
if (res.empty())
cout << "没有在范围内的节点!!!" << endl;
else {
cout << "输入范围内的值如下: " << endl;
for (int i = 0; i < res.size(); i++) {
cout << "(";
for (int j = 0; j < res[i].size(); j++) {
if (j != 0)
cout << ",";
cout << res[i][j];
}
cout << "), ";
}
}
cout << endl;
}
else if (op == 5) {
cout << "当前kd-tree为: " << endl;
tree.display();
cout << endl << "前序: ";
tree.display(tree.rt());
cout << endl << "中序: ";
tree.display2(tree.rt());
cout << endl << endl;
}
}
cout << "byebye!!!" << endl;
system("pause");
return 0;
}