kd 树（k-d tree）的 ADT 实现

实现的功能：
kd 树的初始化（批量构建）、插入、删除、精确查找与范围查找

#include<algorithm>
#include<cmath>
#include<cstdlib>
#include<iostream>
#include<limits>
#include<queue>
#include<vector>
#define INF 0x3f3f3f
using namespace std;


template<class T>
struct kdTreeNode {
	vector<T> data; // coordinates of the point stored in this node
	int split; // index of the axis this node splits on
	kdTreeNode<T> *leftChild; // left child (NULL if none)
	kdTreeNode<T> *rightChild; // right child (NULL if none)
	kdTreeNode<T> *parent; // parent node (NULL for the root)
	T *mx; // mx[i]: maximum of dimension i over this node's subtree
	T *mn; // mn[i]: minimum of dimension i over this node's subtree

	// Empty node: no point and no bounding-box arrays.
	kdTreeNode()
		: split(0), leftChild(NULL), rightChild(NULL), parent(NULL), mx(NULL), mn(NULL) {}

	// Node holding point d; its bounding box initially covers only d itself.
	kdTreeNode(const vector<T> &d)
		: data(d), split(0), leftChild(NULL), rightChild(NULL), parent(NULL),
		  mx(new T[d.size() + 1]), mn(new T[d.size() + 1]) {
		for (size_t i = 0; i < data.size(); i++)
			mx[i] = mn[i] = data[i];
	}

	// The node owns its bounding-box arrays.
	~kdTreeNode() {
		delete[] mx;
		delete[] mn;
	}

	// A member-wise copy would share (and later double-free) mx/mn; nodes are
	// only ever handled through pointers, so copying is disabled.
	kdTreeNode(const kdTreeNode &) = delete;
	kdTreeNode &operator=(const kdTreeNode &) = delete;

	// True iff d is exactly this node's point (same dimension, same values).
	bool cmp(const vector<T> &d) const {
		return data == d;
	}

	// Prints the point as "(x0,x1,...)" with no trailing newline.
	void display() const {
		cout << "(";
		for (size_t i = 0; i < data.size(); i++) {
			if (i != 0)
				cout << ",";
			cout << data[i];
		}
		cout << ")";
	}
};

template<class T>
class kdTree {
private:
	int K; //维度
	kdTreeNode<T> *root; //kd树根节点
	vector<vector<T> > data; //节点向量
	int treeSize; //树节点总数

public:
	kdTree(){}
	kdTree(int k){
		K = k;
		root = NULL;
		treeSize = 0;
	}

	kdTreeNode<T>* rt() { return root; }

	void constructRoot(vector<vector<T> > &num);
	kdTreeNode<T>* construct(vector<vector<T> >&data, int s, int e, int depth);
	void insert(vector<T> newData);
	void erase(vector<T> &Data);
	kdTreeNode<T>* search(vector<T> &Data);
	void rangeSearch(T *rmx, T *rmn, vector<vector<T> > &st, kdTreeNode<T> *temp);
	void reSet(vector<vector<T> > &st, kdTreeNode<T> *p);
	void preO(kdTreeNode<T> *r, vector<vector<T> > &st);
	void display();
	void display(kdTreeNode<T> *r);
	void display2(kdTreeNode<T> *r);
};
template<class T>
// Builds the tree from scratch out of the points in num (a copy is kept in `data`).
// An empty num yields an empty tree and leaves the dimension K unchanged.
void kdTree<T>::constructRoot(vector<vector<T> > &num) {
	data = num;
	treeSize = static_cast<int>(num.size());
	if (num.empty()) {
		// num[0] would be out of bounds, and size() - 1 would wrap around.
		root = NULL;
		return;
	}
	K = static_cast<int>(num[0].size());
	root = construct(data, 0, static_cast<int>(data.size()) - 1, 0);
}
template<class T>
// Recursively builds a balanced kd-subtree from data[s..e] (inclusive).
// depth selects the split axis (depth % K). data[s..e] is reordered in place.
// Returns the subtree root, or NULL for an empty range.
// The returned root's parent is left NULL; callers attach it.
kdTreeNode<T>* kdTree<T>::construct(vector<vector<T> >&data, int s, int e, int depth) {
	int size = e - s + 1; // number of points in the range
	if (size <= 0 || K <= 0) // K <= 0 would make depth % K divide by zero
		return NULL;

	int split = depth % K; // axis this node splits on
	int midIndex = s + size / 2;
	// Only the median on the split axis is needed (smaller-or-equal keys before it,
	// greater-or-equal after), so a partial selection replaces a full sort:
	// O(n) per level instead of O(n log n).
	nth_element(data.begin() + s, data.begin() + midIndex, data.begin() + e + 1,
		[split](const vector<T> &a, const vector<T> &b) { return a[split] < b[split]; });

	kdTreeNode<T> *node = new kdTreeNode<T>(data[midIndex]);
	node->split = split;
	node->leftChild = construct(data, s, midIndex - 1, depth + 1);  // left half
	node->rightChild = construct(data, midIndex + 1, e, depth + 1); // right half
	// Record parent pointers.
	if (node->leftChild != NULL)
		node->leftChild->parent = node;
	if (node->rightChild != NULL)
		node->rightChild->parent = node;

	// Maintain the subtree's per-dimension maximum and minimum.
	for (int i = 0; i < K; i++) {
		if (node->leftChild != NULL) {
			node->mx[i] = max(node->mx[i], node->leftChild->mx[i]);
			node->mn[i] = min(node->mn[i], node->leftChild->mn[i]);
		}
		if (node->rightChild != NULL) {
			node->mx[i] = max(node->mx[i], node->rightChild->mx[i]);
			node->mn[i] = min(node->mn[i], node->rightChild->mn[i]);
		}
	}
	return node;
}
template<class T>
// Inserts newData as a new leaf unless an identical point is already stored.
// The first point into an empty tree becomes the root and fixes K.
// Afterwards every ancestor's bounding box is widened to include the new point.
void kdTree<T>::insert(vector<T> newData) {
	// Duplicates are ignored.
	if (search(newData) != NULL)
		return;

	kdTreeNode<T> *fresh = new kdTreeNode<T>(newData);
	data.push_back(newData);
	++treeSize;

	if (root == NULL) {
		fresh->split = 0;
		K = newData.size();
		root = fresh;
		return;
	}

	// Descend: strictly greater on the split axis goes right, otherwise left.
	kdTreeNode<T> *attachTo = NULL;
	bool goRight = false;
	for (kdTreeNode<T> *cur = root; cur != NULL;
		 cur = goRight ? cur->rightChild : cur->leftChild) {
		attachTo = cur;
		goRight = newData[cur->split] > cur->data[cur->split];
	}

	fresh->parent = attachTo;
	fresh->split = (attachTo->split + 1) % K;
	if (goRight)
		attachTo->rightChild = fresh;
	else
		attachTo->leftChild = fresh;

	// Grow bounding boxes along the path back up to the root.
	for (kdTreeNode<T> *anc = attachTo; anc != NULL; anc = anc->parent) {
		for (int i = 0; i < K; i++) {
			if (newData[i] > anc->mx[i])
				anc->mx[i] = newData[i];
			if (newData[i] < anc->mn[i])
				anc->mn[i] = newData[i];
		}
	}
}
template<class T>
// Removes the point Data from the tree, printing "删除失败!" if it is not present.
// The removed node's subtree is rebuilt from its remaining points. Then the
// bounding boxes of all ancestors are recomputed, because deletion can shrink them.
void kdTree<T>::erase(vector<T> &Data) {
	kdTreeNode<T> *victim = search(Data);
	if (victim == NULL) {
		cout << "删除失败!" << endl;
		return;
	}

	// Drop the point from the master list of stored points.
	typename vector<vector<T> >::iterator pos = find(data.begin(), data.end(), Data);
	if (pos != data.end())
		data.erase(pos);
	treeSize--;

	// Deletes every node of a subtree iteratively (no recursion depth limit).
	auto freeSubtree = [](kdTreeNode<T> *top) {
		vector<kdTreeNode<T>*> pending;
		if (top != NULL)
			pending.push_back(top);
		while (!pending.empty()) {
			kdTreeNode<T> *n = pending.back();
			pending.pop_back();
			if (n->leftChild != NULL)
				pending.push_back(n->leftChild);
			if (n->rightChild != NULL)
				pending.push_back(n->rightChild);
			delete n;
		}
	};

	kdTreeNode<T> *parent = victim->parent;
	if (parent == NULL) {
		// Removing the root: rebuild the whole tree from the remaining points.
		freeSubtree(victim);
		if (data.empty())
			root = NULL;
		else
			root = construct(data, 0, static_cast<int>(data.size()) - 1, 0);
		return;
	}

	// Collect every point below the removed node and rebuild that subtree.
	vector<vector<T> > st;
	reSet(st, victim->leftChild);
	reSet(st, victim->rightChild);
	// st may be empty (victim was a leaf); construct then returns NULL.
	kdTreeNode<T> *rebuilt = construct(st, 0, static_cast<int>(st.size()) - 1, parent->split + 1);
	if (rebuilt != NULL)
		rebuilt->parent = parent; // construct leaves the subtree root's parent unset
	if (parent->rightChild == victim)
		parent->rightChild = rebuilt;
	else
		parent->leftChild = rebuilt;
	freeSubtree(victim);

	// Recompute bounding boxes from scratch up to the root. Widening with max/min
	// alone could never shrink a box after a point is removed.
	for (kdTreeNode<T> *p = parent; p != NULL; p = p->parent) {
		for (int i = 0; i < K; i++) {
			p->mx[i] = p->mn[i] = p->data[i];
			if (p->leftChild != NULL) {
				p->mx[i] = max(p->mx[i], p->leftChild->mx[i]);
				p->mn[i] = min(p->mn[i], p->leftChild->mn[i]);
			}
			if (p->rightChild != NULL) {
				p->mx[i] = max(p->mx[i], p->rightChild->mx[i]);
				p->mn[i] = min(p->mn[i], p->rightChild->mn[i]);
			}
		}
	}
}
template<class T>
// Exact-match lookup: returns the node storing Data, or NULL if absent.
// Equal keys on a split axis can land on either side, so both children may
// have to be explored. A child is visited only if Data lies inside its
// bounding box on every dimension (left before right).
kdTreeNode<T>* kdTree<T>::search(vector<T> &Data) {
	if (root == NULL || Data.size() != static_cast<size_t>(K))
		return NULL;

	vector<kdTreeNode<T>*> pending(1, root);
	while (!pending.empty()) {
		kdTreeNode<T> *node = pending.back();
		pending.pop_back();
		if (node->cmp(Data))
			return node;

		// Push right first so the left subtree is examined first.
		kdTreeNode<T> *kids[2] = { node->rightChild, node->leftChild };
		for (int c = 0; c < 2; c++) {
			kdTreeNode<T> *child = kids[c];
			if (child == NULL)
				continue;
			bool inBox = true;
			for (int i = 0; i < K && inBox; i++)
				if (Data[i] < child->mn[i] || Data[i] > child->mx[i])
					inBox = false;
			if (inBox)
				pending.push_back(child);
		}
	}
	return NULL;
}
template<class T>
// Appends the point of every node in the subtree rooted at p to st, in
// pre-order (node, then left subtree, then right subtree). A NULL p adds nothing.
void kdTree<T>::reSet(vector<vector<T> > &st, kdTreeNode<T> *p) {
	vector<kdTreeNode<T>*> todo;
	if (p != NULL)
		todo.push_back(p);
	while (!todo.empty()) {
		kdTreeNode<T> *node = todo.back();
		todo.pop_back();
		st.push_back(node->data);
		// Right is pushed before left so left is popped (visited) first.
		if (node->rightChild != NULL)
			todo.push_back(node->rightChild);
		if (node->leftChild != NULL)
			todo.push_back(node->leftChild);
	}
}
template<class T>
// Appends to st every point in the subtree rooted at temp that lies inside the
// axis-aligned box [rmn[i], rmx[i]] on every dimension i < K (bounds inclusive).
// Subtrees whose bounding box is disjoint from the query are skipped. Subtrees
// fully inside the query are copied wholesale.
void kdTree<T>::rangeSearch(T *rmx, T *rmn, vector<vector<T> > &st, kdTreeNode<T> *temp) {
	if (temp == NULL)
		return;

	bool fullyInside = true;
	for (int i = 0; i < K; i++) {
		if (temp->mx[i] < rmn[i] || temp->mn[i] > rmx[i])
			return; // disjoint on axis i: nothing in this subtree can match
		if (temp->mx[i] > rmx[i] || temp->mn[i] < rmn[i])
			fullyInside = false;
	}
	if (fullyInside) {
		preO(temp, st);
		return;
	}

	// Test this node's own point on all K dimensions (not just the first two).
	bool pointInside = true;
	for (int i = 0; i < K && pointInside; i++)
		if (temp->data[i] > rmx[i] || temp->data[i] < rmn[i])
			pointInside = false;
	if (pointInside)
		st.push_back(temp->data);

	rangeSearch(rmx, rmn, st, temp->leftChild);
	rangeSearch(rmx, rmn, st, temp->rightChild);
}
template<class T>
// Collects the points of the subtree rooted at r into st in pre-order.
// This is exactly the traversal reSet performs, so it is reused here.
void kdTree<T>::preO(kdTreeNode<T> *r, vector<vector<T> > &st) {
	reSet(st, r);
}
template<class T>
// Prints the tree level by level: one output line per depth, nodes separated by
// two spaces. An empty tree prints nothing.
// Lines are broken by each level's actual size; the old 2^layer - 1 node-count
// rule only holds for complete trees, and it dereferenced NULL on an empty tree.
void kdTree<T>::display() {
	if (root == NULL)
		return;
	queue<kdTreeNode<T>* > q;
	q.push(root);
	while (!q.empty()) {
		size_t levelSize = q.size();
		for (size_t j = 0; j < levelSize; j++) {
			kdTreeNode<T> *temp = q.front();
			q.pop();
			if (j != 0)
				cout << "  ";
			temp->display();
			if (temp->leftChild != NULL)
				q.push(temp->leftChild);
			if (temp->rightChild != NULL)
				q.push(temp->rightChild);
		}
		cout << endl;
	}
}
template<class T>
// Prints the subtree rooted at r in pre-order, each point followed by one space.
void kdTree<T>::display(kdTreeNode<T> *r) {
	vector<kdTreeNode<T>*> stack;
	if (r != NULL)
		stack.push_back(r);
	while (!stack.empty()) {
		kdTreeNode<T> *node = stack.back();
		stack.pop_back();
		node->display();
		cout << " ";
		// Right first so the left subtree is printed before it.
		if (node->rightChild != NULL)
			stack.push_back(node->rightChild);
		if (node->leftChild != NULL)
			stack.push_back(node->leftChild);
	}
}
template<class T>
// Prints the subtree rooted at r in in-order (left, node, right), each point
// followed by one space.
void kdTree<T>::display2(kdTreeNode<T> *r) {
	vector<kdTreeNode<T>*> stack;
	kdTreeNode<T> *cur = r;
	while (cur != NULL || !stack.empty()) {
		// Walk as far left as possible, remembering the path.
		for (; cur != NULL; cur = cur->leftChild)
			stack.push_back(cur);
		cur = stack.back();
		stack.pop_back();
		cur->display();
		cout << " ";
		cur = cur->rightChild;
	}
}


template<class T>
// Interactively reads the dimension k, a point count a, then a points of k
// values each (whitespace separated) from cin, appending them to num.
// Values are read as T (previously always int, which did not compile for T != int).
// Reading stops at the first stream failure; an incomplete trailing point is discarded.
void input(int &k, vector<vector<T> > &num) {
	int a = 0;
	cout << "输入维度k: ";
	cin >> k;
	cout << "输入初始构造的节点数a: ";
	cin >> a;
	cout << "输入a个k维数据(用空格分开各个数): " << endl;
	for (int p = 0; p < a && cin; p++) {
		vector<T> point;
		T value = T();
		for (int d = 0; d < k && cin >> value; d++)
			point.push_back(value);
		if (static_cast<int>(point.size()) == k)
			num.push_back(point);
	}
}
template<class T>
// Reads up to k values of type T from cin and appends them to data.
// Stops early if the stream fails rather than appending meaningless values.
void input_opData(int &k, vector<T> &data) {
	T value = T();
	for (int i = 0; i < k && cin >> value; i++)
		data.push_back(value);
}
template<class T>
// Interactively fills per-dimension bounds s[0..k-1]. The user repeatedly
// enters a 1-based dimension number and a value; entering k+1 ends input.
// Dimension numbers outside [1, k] are ignored instead of writing out of bounds.
// Input also ends on a stream failure, which previously looped forever.
void input_ranData(int &k, T *s) {
	while (true) {
		int i = 0;
		T val = T();
		cout << "输入维度: ";
		if (!(cin >> i) || i == k + 1)
			break;
		cout << "输入值: ";
		if (!(cin >> val))
			break;
		if (i >= 1 && i <= k)
			s[i - 1] = val;
	}
}


int main(void) {
	int k = 0;
	vector<vector<int> > num;
	// Initial input and bulk construction.
	input(k, num);
	// Construct with the dimension so every tree member is initialized even
	// when no initial points were entered (constructRoot would index num[0]).
	kdTree<int> tree(k);
	if (!num.empty())
		tree.constructRoot(num);

	// Interactive command loop.
	while (1) {
		int op;
		cout << "请输入操作码op (1:插入 2:删除 3:精确查找 4:范围查询 5:显示 6:退出): ";
		// End of input or unreadable input ends the session instead of spinning forever.
		if (!(cin >> op))
			break;
		if (op == 6)
			break;
		if (op == 1) {
			cout << "输入要插入的k维数据点: ";
			vector<int> data;
			input_opData(k, data);
			tree.insert(data);
			cout << endl;
		}
		else if (op == 2) {
			cout << "输入要删除的k维数据点: ";
			vector<int> data;
			input_opData(k, data);
			tree.erase(data);
			cout << endl;
		}
		else if (op == 3) {
			cout << "输入要精确查找的k维数据点: ";
			vector<int> data;
			input_opData(k, data);
			if (tree.search(data) == NULL)
				cout << "没有找到!!!" << endl;
			else
				cout << "找到" << endl;
			cout << endl;
		}
		else if (op == 4) {
			// Dimensions the user does not constrain stay unbounded. The full int range
			// is used instead of a finite sentinel, which would silently exclude large
			// coordinates. vector storage also replaces the new[] arrays that leaked per query.
			int dims = k > 0 ? k : 0;
			vector<int> mx(dims, numeric_limits<int>::max());
			vector<int> mn(dims, numeric_limits<int>::min());

			cout << "输入查询范围 各维度上限(k+1表示停止输入): " << endl;
			input_ranData(k, mx.data());
			cout << "输入查询范围 各维度下限(k+1表示停止输入): " << endl;
			input_ranData(k, mn.data());

			vector<vector<int> > res;
			tree.rangeSearch(mx.data(), mn.data(), res, tree.rt());
			if (res.empty())
				cout << "没有在范围内的节点!!!" << endl;
			else {
				cout << "输入范围内的值如下: " << endl;
				for (size_t i = 0; i < res.size(); i++) {
					cout << "(";
					for (size_t j = 0; j < res[i].size(); j++) {
						if (j != 0)
							cout << ",";
						cout << res[i][j];
					}
					cout << "), ";
				}
			}
			cout << endl;
		}
		else if (op == 5) {
			cout << "当前kd-tree为: " << endl;
			tree.display();
			cout << endl << "前序: ";
			tree.display(tree.rt());
			cout << endl << "中序: ";
			tree.display2(tree.rt());
			cout << endl << endl;
		}
	}
	cout << "byebye!!!" << endl;

#ifdef _WIN32
	// `pause` is a Windows cmd builtin; other platforms have no such command.
	system("pause");
#endif
	return 0;
}

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值