k邻近算法之 搜索kd树从 最邻近到 k邻近

  • 写在前面

在一番摸索之下,博主利用国庆的时间,系统地了解k邻近算法。从第一次接触,到完整地用c++代码实现利用kd树来进行 k邻近搜索。弄清了很多细节,但也可能有很多不足之处,小伙伴们尽情板砖。

算法的主体都有详细的程序,只是将样本放在矩阵库(armadillo)中。当然小伙伴们用opencv库更好,博主也打算转用opencv库。

  • k邻近分类算法简述

首先我们有m个样本,每个样本有n个特征,而且每个样本有自己的分类标签。此时,给出一个测试样例,如何去判断该样例是属于什么分类呢。k邻近算法给我们提供了一个思路:找到距离测试样本最近的k个样本。这k个样本大多数属于哪个类别,那么该测试样例属于这个类别的可能性就比较大。这里的距离就使用欧式距离(平方开根号)。

那么如何去找到这k个样本呢?首先想到的办法就是把m个样本全部算一遍距离。这个方案在样本m数据量很大时,会比较耗时。现在有一种方案,就是使用kd树去寻找这k个样本。其精髓与二分查找类似。

  • kd树的建立

kd树的结构与二叉树类似

简述步骤:

  1. 我们有个样本矩阵m行,n列。m是指有m个样本,n指有n维特征。那么首先找到第一个特征的中间值所在的样本(即找到矩阵第一列的中间值所在的行)。
  2. 将这个样本作为根节点。然后将剩余的样本分成两部分,第一部分中样本的第一个特征值都小于根节点的第一个特征值。第二部分中样本的第一个特征值都大于根节点的第一个特征值。
  3. 然后我们同样对这两个矩阵进行分割。此时我们用第二列特征值,同样找到中间值。以此不断分割,直到所有的样本都在kd树的节点上。

我们发现步骤2与步骤三可以用相同的函数实现,这不是勾引我们用递归吗。。当然这与二叉树不是类似吗。

具体实现:直接上代码,代码中用了armadillo矩阵库(只是辅助用,不包含算法),和大多数矩阵用法类似,具体可以参考

http://arma.sourceforge.net/docs.html

 

//kd树节点结构体
struct kdTreeNode{
	mat data;				//节点数据向量 1*n  n为特征数
	kdTreeNode * leftP;		//左子树指针
	kdTreeNode * rightP;	//右子树指针
	kdTreeNode * fatherP;	//父亲节点指针,竟然没用到
	int splitIndex;			//当前节点分割数据的特征索引,目前是轮换特征
	kdTreeNode(){leftP = NULL;rightP = NULL;fatherP = NULL;splitIndex = 0;}//构造函数
};
//使用递归建立kd树
/*
	nodeP 结构体指针
	data  输入数据矩阵,m*n m为样本数 n为特征数
	depth 节点深度
*/
void buildKdTree(kdTreeNode *& nodeP,mat data,int depth)
{
	//如果输入的数据矩阵为空,则返回
	if(data.n_rows == 0)
	{
		return;
	}
	//如果只有一个样本,则直接将数据给节点
	if(data.n_rows == 1)
	{
		nodeP->data = data.row(0);
		return;
	}
	//此时样本数大于等于2
	int splitFeatureIndex = depth % data.n_cols;	//计算分割数据的特征索引
	nodeP->splitIndex = splitFeatureIndex;		
	vec splitFeature = data.col(splitFeatureIndex);	//取得特征列向量
	vec splitOrder = sort(splitFeature);			//排序
	double medianValue = splitOrder[data.n_rows/2]; //取中值
	mat subsetLeft,subsetRight;						//左右矩阵
	//接下来是将样本矩阵分为左右两个矩阵
	for(int i=0;i<data.n_rows;i++)
	{
		if(nodeP->data.empty()&&splitFeature[i] == medianValue)
		{
			//将中间样本给节点
			nodeP->data = data.row(i);
		}else{
			if(splitFeature[i] < medianValue)
			{
				subsetLeft = join_vert( subsetLeft,data.row(i));
			}else{
				subsetRight = join_vert( subsetRight,data.row(i));
			}
		}
	}	
	nodeP->leftP = new kdTreeNode;
	nodeP->leftP->fatherP = nodeP;
	nodeP->rightP = new kdTreeNode;
	nodeP->rightP->fatherP = nodeP;
	//递归进入左右子树
	buildKdTree(nodeP->leftP,subsetLeft,depth+1);
	buildKdTree(nodeP->rightP,subsetRight,depth+1);
}

建立完之后,我们可以写个遍历程序验证下

//中序遍历
void inorder_traverse(kdTreeNode * nodeP)
{
	if(nodeP == NULL || nodeP->data.empty())
	{
		return;
	}		
	inorder_traverse(nodeP->leftP);
	cout<<nodeP->data<<endl;
	inorder_traverse(nodeP->rightP);
}
//前序遍历
void preorder_traverse(kdTreeNode * nodeP)
{
	if(nodeP == NULL || nodeP->data.empty())
	{
		return;
	}	
	cout<<nodeP->splitIndex<<" "<<nodeP->data<<endl;
	preorder_traverse(nodeP->leftP);	
	preorder_traverse(nodeP->rightP);
}
  • 最邻近搜索

现有一个测试目标,为1*n的矩阵,n为特征数,现在就在样本矩阵找到与这个目标距离最近的样本。

搜索就不用迭代了,这样也好理解

具体步骤:

  • 第一步:利用二叉搜索到子节点并将,节点指针顺序压入堆栈stack中

使用测试样本的第一个特征值,与根节点比较,若特征值小于或等于根节点的第一个特征值,则下一步在左子树搜索,反之则在右子树。如此循环直至找到叶子节点。并且在这个过程中,不断将得到的节点压入堆栈中,这就形成我们的初步搜索路径,按照堆栈先进后出的原则,堆栈最顶端的元素将是叶子节点。

  • 第二步:通过保存在堆栈中的搜索路径回溯

首先计算出最后的叶子结点与目标的距离保存为最短距离minDIstance,并将它设置成最邻近样本

(1)从堆栈中取出一个节点

(2)若是叶子结点,则只有一个步骤

                      1.计算与目标的距离,若小于保存的minDistance,则将mindistance替换为当前计算的距离,并将当前样本设置为最邻近样本。若大于则跳过。

         若不是叶子节点,则有三个步骤

                       1.此步骤与叶子结点的那个步骤相同

                      2.计算目标与当前节点划分的超平面的距离(这个可以这么理解:假如在二维直角坐标系中计算某个点与x轴或与x轴平行的一条线的距离,怎么算,这个很显然吧~,同样的,与超平面的距离你应该知道怎么算了)

                     3.假如目标到超平面的距离大于记录的minDistance。则不执行任何操作,等着下一次回溯。反之若小于minDistance,则说明在当前节点的另一个子树中可能有更近的点。这样,我们找到另一个子树的节点,将它压入到堆栈中(即加入到搜索路径中)并且按照 第一步 的方法,一边往下二叉搜索,一边将遇到的节点压入堆栈中。然后就等着下一次回溯(即从堆栈中取节点)

(3)不断循环步骤(2)直到堆栈中的数据都取出,即搜索路径都走了一遍。

代码到贴k邻近再贴吧

  • k 邻近搜索

网上一直没找到关于k邻近的算法步骤,我自己在最邻近的基础上修改了下,发现可行,不知与官方的算法是否有差距。

其中主体算法步骤与最邻近相似。

就说明一下不同之处:

第一点不同:相较于最邻近只需找一个样本,k邻近算法中,我们定义有个结构体数组来记录k个样本

第二点不同:我们不管他是不是真的最近的k个样本先找到k个再说。所以在收集到k个样本之前,无视最邻近算法中是否小于minDistance这个条件,统统收集到数组中,直到找满k个样本。

第三点不同:找到k个样本之后,立刻找出k个样本中的最大距离的样本,此距离记为maxD,因为他最不靠谱。。恢复最邻近算法中的条件,找到一个比maxD距离更小的样本,就可以将它替换。此时重新计算最maxD。直到堆栈为空。

  • 最后

把所有代码都贴上吧。所有代码都在一个文件中。想要运行需要加矩阵库armadillo。x64

如果需要工程文件的,我也上传了。里面包含了矩阵库的lib,dll,头文件,可以直接运行,不需要配置。用的是vs2012 x64

https://download.csdn.net/download/qq_32478489/10707377

#include <iostream>
#include <time.h>
#include <armadillo>
#include <stack>
using namespace std;
using namespace arma;
//kd树节点结构体
struct kdTreeNode{
	mat data;				//节点数据向量 1*n  n为特征数
	kdTreeNode * leftP;		//左子树指针
	kdTreeNode * rightP;	//右子树指针
	kdTreeNode * fatherP;	//父亲节点指针,竟然没用到
	int splitIndex;			//当前节点分割数据的特征索引,目前是轮换特征
	kdTreeNode(){leftP = NULL;rightP = NULL;fatherP = NULL;splitIndex = 0;}//构造函数
};
//使用递归建立kd树
/*
	nodeP 结构体指针
	data  输入数据矩阵,m*n m为样本数 n为特征数
	depth 节点深度
*/
void buildKdTree(kdTreeNode *& nodeP,mat data,int depth)
{
	//如果输入的数据矩阵为空,则返回
	if(data.n_rows == 0)
	{
		return;
	}
	//如果只有一个样本,则直接将数据给节点
	if(data.n_rows == 1)
	{
		nodeP->data = data.row(0);
		return;
	}
	//此时样本数大于等于2
	int splitFeatureIndex = depth % data.n_cols;	//计算分割数据的特征索引
	nodeP->splitIndex = splitFeatureIndex;		
	vec splitFeature = data.col(splitFeatureIndex);	//取得特征列向量
	vec splitOrder = sort(splitFeature);			//排序
	double medianValue = splitOrder[data.n_rows/2]; //取中值
	mat subsetLeft,subsetRight;						//左右矩阵
	//接下来是将样本矩阵分为左右两个矩阵
	for(int i=0;i<data.n_rows;i++)
	{
		if(nodeP->data.empty()&&splitFeature[i] == medianValue)
		{
			//将中间样本给节点
			nodeP->data = data.row(i);
		}else{
			if(splitFeature[i] < medianValue)
			{
				subsetLeft = join_vert( subsetLeft,data.row(i));
			}else{
				subsetRight = join_vert( subsetRight,data.row(i));
			}
		}
	}	
	nodeP->leftP = new kdTreeNode;
	nodeP->leftP->fatherP = nodeP;
	nodeP->rightP = new kdTreeNode;
	nodeP->rightP->fatherP = nodeP;
	//递归进入左右子树
	buildKdTree(nodeP->leftP,subsetLeft,depth+1);
	buildKdTree(nodeP->rightP,subsetRight,depth+1);
}
//中序遍历
void inorder_traverse(kdTreeNode * nodeP)
{
	if(nodeP == NULL || nodeP->data.empty())
	{
		return;
	}		
	inorder_traverse(nodeP->leftP);
	cout<<nodeP->data<<endl;
	inorder_traverse(nodeP->rightP);
}
//前序遍历
void preorder_traverse(kdTreeNode * nodeP)
{
	if(nodeP == NULL || nodeP->data.empty())
	{
		return;
	}	
	cout<<nodeP->splitIndex<<" "<<nodeP->data<<endl;
	preorder_traverse(nodeP->leftP);	
	preorder_traverse(nodeP->rightP);
}
//计算欧氏距离
double EuclidianDis(mat & a,mat & b)
{
	vec v1 = a.row(0).t();
	vec v2 = b.row(0).t();	
	vec c = v1-v2;
	return norm(c,2);
}
struct KNN{
	kdTreeNode * nodeP;
	double distance;
	KNN(){nodeP = NULL;distance = -1;}
};

int isSearched(mat data,mat check)
{
	vec d;
	vec c = check.row(0).t();
	for(int i=0;i<data.n_rows;i++)
	{
		d = data.row(i).t();
		if(all(d==c))
		{
			return 1;
		}
	}
	return 0;
}
void findMaxInNearGroup(KNN *& NearGroup,int K,int & maxInNG)
{
	double maxDis=NearGroup[0].distance;
	int maxN=0;
	for(int i=1;i<K;i++)
	{
		if(maxDis<NearGroup[i].distance)
		{
			maxDis = NearGroup[i].distance;
			maxN = i;
		}
	}
	maxInNG = maxN;
}
//利用kd树搜索k邻近节点
void findNearestK(kdTreeNode * RootNodeP,mat & target,KNN *& NearGroup,int & K)
{
	//如果指针为空,则返回
	if(RootNodeP == NULL) return;
	if(K==0)
	{
	    return;
	}
	/*if(K > RootNodeP.n_rows) 
	{
		K = target.n_rows;
	}	*/
	int KgetN = 0;
	//第一步:利用二叉搜索到子节点并将,节点指针顺序压入堆栈stack中
	kdTreeNode * nodeP = RootNodeP;
	stack <kdTreeNode *> search_path;
	int maxInNGIndex = 0;
	while(nodeP !=NULL && !nodeP->data.empty())//判断遇到空节点,或者节点内无数据,则结束
	{
		search_path.push(nodeP);
		if(target(0,nodeP->splitIndex) <= nodeP->data(0,nodeP->splitIndex))
		{
			nodeP = nodeP->leftP;
		}else{
			nodeP = nodeP->rightP;
		}
	}
	kdTreeNode * firstNodeP = search_path.top();
	NearGroup[0].nodeP= search_path.top();					//将得到的叶子结点记入下来
	NearGroup[0].distance = EuclidianDis(target,NearGroup[0].nodeP->data);	//并计算距离
	KgetN = 1;													//	
	//第二步:然后通过堆栈回溯

	kdTreeNode * back_nodeP;
	while(search_path.empty() == 0)				//直到堆栈为空,即已经回溯完为止
	{
		back_nodeP = search_path.top();
		search_path.pop();
		//若当前回溯节点为叶子节点,只计算与target距离即可。
		//如果距离小于当前记录的最小距离,则更新最小值,否则跳过即可
		if((back_nodeP->leftP == NULL|| back_nodeP->leftP->data.empty())
			&&(back_nodeP->rightP == NULL||back_nodeP->rightP->data.empty()))
		{
			if(back_nodeP != firstNodeP)
			{
				if((KgetN<K)||(EuclidianDis(back_nodeP->data,target)<NearGroup[maxInNGIndex].distance))
				{							
					if(KgetN<K)
					{
						NearGroup[KgetN].nodeP = back_nodeP;
						NearGroup[KgetN].distance = EuclidianDis(back_nodeP->data,target);
						KgetN++;
						if(KgetN==K)
						{
							findMaxInNearGroup(NearGroup,KgetN,maxInNGIndex);
						}
					}else{
						NearGroup[maxInNGIndex].nodeP = back_nodeP;
						NearGroup[maxInNGIndex].distance = 	EuclidianDis(back_nodeP->data,target);
						findMaxInNearGroup(NearGroup,KgetN,maxInNGIndex);
					}
				}
			}			
		}else{
			//若当前回溯点带有子树,不仅需要计算距离,并更新
			//而且需要计算目标到超平面的距离,即特征值相减即可
			if((KgetN<K)||(EuclidianDis(back_nodeP->data,target)<NearGroup[maxInNGIndex].distance))
			{							
				if(KgetN<K)
				{
					NearGroup[KgetN].nodeP = back_nodeP;
					NearGroup[KgetN].distance = EuclidianDis(back_nodeP->data,target);
					KgetN++;
					if(KgetN==K)
					{
						findMaxInNearGroup(NearGroup,KgetN,maxInNGIndex);
					}
				}else{
					NearGroup[maxInNGIndex].nodeP = back_nodeP;
					NearGroup[maxInNGIndex].distance = 	EuclidianDis(back_nodeP->data,target);
					findMaxInNearGroup(NearGroup,KgetN,maxInNGIndex);
				}
			}
			//计算到超品面距离,如果距离小于当前记录的最小距离
			//则需要进入另一个子树,并像第一步一般,不断二叉搜索,并将搜索路径压入堆栈
			if((KgetN<K)||(abs(target(0,back_nodeP->splitIndex)-back_nodeP->data(0,back_nodeP->splitIndex))<NearGroup[maxInNGIndex].distance))
			{
				kdTreeNode * childNodeP;
				if(target(0,back_nodeP->splitIndex) > back_nodeP->data(0,back_nodeP->splitIndex))
				{
					childNodeP = back_nodeP->leftP;
				}else
				{
					childNodeP = back_nodeP->rightP;
				}				
				while(childNodeP !=NULL && !childNodeP->data.empty())
				{
				    search_path.push(childNodeP);
					if(target(0,childNodeP->splitIndex) <= childNodeP->data(0,childNodeP->splitIndex))
					{
						childNodeP = childNodeP->leftP;
					}else{
						childNodeP = childNodeP->rightP;
					}
				}
			}
		}
	}
}
//线性扫描的方法得到最邻近值,用以验证算法
void scanDis(mat & data,mat & target,mat & nearest,double & distance)
{
	mat a = data.row(0);
	mat b = target;
	mat near = data.row(0);
	double minDis=EuclidianDis(a,b);
	for(int i=1;i<data.n_rows;i++)
	{
		a = data.row(i);
		b = target;
		double dis = EuclidianDis(a,b);
		if(dis<minDis)
		{
			minDis = dis;
			near = data.row(i);			
		}
	}
	distance = minDis;
	nearest = near;
}
void main(void)
{
	mat dataTrain,target,nearest,nearest1,searched;
	int nK = 5;
	KNN * nearstK = new KNN[nK];
	double minDistance,minDistance1;
	dataTrain<<2<<3<<endr
			<<5<<4<<endr
			<<9<<6<<endr
			<<4<<7<<endr
			<<8<<1<<endr
			<<7<<2<<endr;
	target<<2<<4.5<<endr;
	/*dataTrain = randu<mat>(50000,5)*100;
	target = randu<mat>(1,5)*100;*/
	
	kdTreeNode * rootP = new kdTreeNode;
	buildKdTree(rootP,dataTrain,0);
//	inorder_traverse(rootP);

	cout<<"target "<<endl<<target<<endl;

    clock_t arithmeticStart = clock();	
	findNearestK(rootP,target,nearstK,nK);
	clock_t arithmeticEnd = clock();
	cout<<"KD Tree result"<<endl;
	for(int i=0;i<nK;i++)
	{
		cout<<nearstK[i].nodeP->data<<endl;
		cout<<nearstK[i].distance<<endl;
	}
	double runTime = (double)(arithmeticEnd - arithmeticStart)/1000;
	cout<<"KD Tree Running Time : "<<runTime<<" s"<<endl;	

	arithmeticStart = clock();	
	scanDis(dataTrain,target,nearest1,minDistance1);
	arithmeticEnd = clock();
	cout<<"Scan result"<<endl<<nearest1<<minDistance1<<endl;
	runTime = (double)(arithmeticEnd - arithmeticStart)/1000;
	cout<<"Scan Running Time : "<<runTime<<" s"<<endl;
	
	cout<<"press any key to terminate.";
	getchar();
}

 

  • 4
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值