kd树

K-近邻法

距离度量标准:欧式距离或更一般的 L p L_p Lp 距离。

k 值的选择:k 值小时,k 近邻模型更复杂;k 值大时,模型更简单(当 k=N 时,最简单);用交叉验证法取得最合适的 k 值。

分类决策原则:多数表决。

构造kd树

kd树用于搜索与规定点空间距离最小的点。

kd树与线段树类似,线段树存某一区间,kd树存某一k维空间。

构造kd树的方法也是从一整个空间开始,递归往下分配空间。

每个节点存其所在维度,父子节点,以便之后搜索。

#include<iostream>
#include<algorithm>
#include<vector>
using namespace std;
const int k = 2;		//k维度
const int maxn = 100;		//最大节点数
int cmp_dim;			//排序时的维度
vector<vector<int> >dat;	//存节点坐标
struct Node {			//节点
	int value[k];			//坐标
	int dim;			//节点所在维度
	Node* parant=NULL;
	Node* left_ch=NULL;	//父节点,子节点
	Node* right_ch=NULL;
}kdnode[maxn];
bool cmp(const Node& a,const Node& b) {
	return a.value[cmp_dim] < b.value[cmp_dim];
}

Node* build(int l,int r,int d) {
	if (l > r)return NULL;
	int m;
	cmp_dim = d;
	if ((l + r) & 1)m = 1+(l + r) >> 1;	//中位点
	else m = (l + r) >> 1;
	nth_element(kdnode + l, kdnode + m, kdnode + r + 1,cmp);
	kdnode[m].dim = d;
	 Node* lch = build(l, m-1, (d + 1) % k);	//递归建树
	if(lch!=NULL)
		lch->parant = kdnode + m;
	Node* rch = build(m + 1, r, (d + 1) % k);
	if(rch!=NULL)
		rch->parant = kdnode + m;
	return kdnode + m;
	
}

int main() {
	int n = 6;
	dat.push_back({ 2,3 });
	dat.push_back({ 5,4 });
	dat.push_back({ 9,6 });
	dat.push_back({ 4,7 });
	dat.push_back({ 8,1 });
	dat.push_back({ 7,2 });
	for (int i = 0; i < n; i++) {
		for (int j = 0; j < k; j++) {
			kdnode[i].value[j] = dat[i][j];
		}
	}
	build(0, n - 1, 0);

	return 0;
}

搜索kd树

  1. 找到包含目标点的叶节点

  2. 以此节点为当前最近节点

  3. 向上回退。同时检查:

    1. 若父节点更近,更新;

    2. 若兄弟节点空间可能存在更近的点(到轴距离小于当前最小距离),检查兄弟节点。

  4. 当回到根节点,搜索结束。最终最近节点即为结果。

所有代码:

#include<iostream>
#include<algorithm>
#include<vector>
#include<cmath>
using namespace std;
const int k = 2;		//k维度
const int maxn = 100;		//最大节点数
int cmp_dim;			//排序时的维度
vector<vector<int> >dat;	//存节点坐标
struct Node {			//节点
	int value[k];			//坐标
	int dim;			//节点所在维度
	Node* parent=NULL;
	Node* left_ch=NULL;	//父节点,子节点
	Node* right_ch=NULL;
}kdnode[maxn];
bool cmp(const Node& a,const Node& b) {
	return a.value[cmp_dim] < b.value[cmp_dim];
}

Node* build(int l,int r,int d) {
	if (l > r)return NULL;
	int m;
	cmp_dim = d;
	if ((l + r) & 1)m = 1+(l + r) >> 1;	//中位点
	else m = (l + r) >> 1;
	nth_element(kdnode + l, kdnode + m, kdnode + r + 1,cmp);
	kdnode[m].dim = d;
	 Node* lch = build(l, m-1, (d + 1) % k);	//递归建树
	 if (lch != NULL)
	 {
		 lch->parent = kdnode + m;
		 kdnode[m].left_ch = lch;
	 }
	Node* rch = build(m + 1, r, (d + 1) % k);
	if (rch != NULL)
	{
		rch->parent = kdnode + m;
		kdnode[m].right_ch = rch;
	}
	return kdnode + m;
	
}

//distance返回欧式距离的平方
int distance(int* a,int* b) {
	int res = 0;
	for (int i = 0; i < k; i++)
		res += (a[i]-b[i])*(a[i]-b[i]);
	return res;
}

Node* root;			//kd树的根
int min_dis;			//最近点与目标点距离
Node* nearest;		//最近点

//找到包含目标点的叶节点
Node* the_one(int* arr) {
	int d = 0;
	Node* r=root;
	Node* tmp=r;
	while (r != NULL) {
		tmp = r;
		if (r->value[d] > arr[d]) {
			r = r->left_ch;
		}
		else
			r = r->right_ch;
		d = (d + 1) % k;
	}
	min_dis = distance(arr, tmp->value);
	nearest = tmp;
	return tmp;
}
//在某一节点 从上往下 找最近点
void search_son(Node* son, int* tar) {
	if (son == NULL)
		return;
	if (distance(son->value, tar) < min_dis) {	//如果当前点更近,则更新
		nearest = son;
		min_dis = distance(son->value, tar);
	}
	//若轴与目标点距离小于当前最小距离,则两边都要找
	if (pow(son->value[son->dim] - tar[son->dim], 2) < min_dis) {
		search_son(son->left_ch, tar);
		search_son(son->right_ch, tar);
	}
	//只要找一边
	else {
		//目标点在轴左边,找左边空间
		if (son->value[son->dim] > tar[son->dim])
			search_son(son->left_ch, tar);
		else
			search_son(son->right_ch, tar);
	}
}

//在某一点 从下往上 找最近点
void search_all(Node* now,int tar[k]) {
	Node* par = now->parent;
	while (now->parent != NULL) {
		now = now->parent;
		//若父节点更近,则更新
		if (distance(tar, now->value) < min_dis) {
			nearest = now;
			min_dis = distance(tar, now->value);
		}
		Node* son;
		//若轴与目标点距离小于当前最小距离,还要在另一边找
		if (pow(now->value[now->dim], 2) < min_dis) {
			//目标点在轴右边,找右边空间
			if (now->value[now->dim] > tar[now->dim])
				son = now->right_ch;
			else
				son = now->left_ch;
			//确定在哪边空间后,从上往下找
			search_son(son,tar);
		}
	}
}
int main() {
	int n = 6;
	dat.push_back({ 2,3 });
	dat.push_back({ 5,4 });
	dat.push_back({ 9,6 });
	dat.push_back({ 4,7 });
	dat.push_back({ 8,1 });
	dat.push_back({ 7,2 });
	for (int i = 0; i < n; i++) {
		for (int j = 0; j < k; j++) {
			kdnode[i].value[j] = dat[i][j];
		}
	}
	root=build(0, n - 1, 0);
	//测试
	int a[2] = { 8,2 };
	Node* now = the_one(a);
	search_all(now, a);
	//distance返回的是最小距离的平方,结果还应开根号
	cout << distance(nearest->value, a);
	return 0;
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值