-
写在前面
在一番摸索之下,博主利用国庆的时间,系统地了解k邻近算法。从第一次接触,到完整地用c++代码实现利用kd树来进行 k邻近搜索。弄清了很多细节,但也可能有很多不足之处,小伙伴们尽情板砖。
算法的主体都有详细的程序,只是将样本放在矩阵库(armadillo)中。当然小伙伴们用opencv库更好,博主也打算转用opencv库。
-
k邻近分类算法简述
首先我们有m个样本,每个样本有n个特征,而且每个样本有自己的分类标签。此时,给出一个测试样例,如何去判断该样例是属于什么分类呢。k邻近算法给我们提供了一个思路:找到距离测试样本最近的k个样本。这k个样本大多数属于哪个类别,那么该测试样例属于这个类别的可能性就比较大。这里的距离就使用欧式距离(平方开根号)。
那么如何去找到这k个样本呢?首先想到的办法就是把m个样本全部算一遍距离。这个方案在样本m数据量很大时,会比较耗时。现在有一种方案,就是使用kd树去寻找这k个样本。其精髓与二分查找类似。
-
kd树的建立
kd树的结构与二叉树类似
简述步骤:
- 我们有个样本矩阵m行,n列。m是指有m个样本,n指有n维特征。那么首先找到第一个特征的中间值所在的样本(即找到矩阵第一列的中间值所在的行)。
- 将这个样本作为根节点。然后将剩余的样本分成两部分,第一部分中样本的第一个特征值都小于根节点的第一个特征值。第二部分中样本的第一个特征值都大于根节点的第一个特征值。
- 然后我们同样对这两个矩阵进行分割。此时我们用第二列特征值,同样找到中间值。以此不断分割,直到所有的样本都在kd树的节点上。
我们发现步骤2与步骤三可以用相同的函数实现,这不是勾引我们用递归吗。。当然这与二叉树不是类似吗。
具体实现:直接上代码,代码中用了armadillo矩阵库(只是辅助用,不包含算法),和大多数矩阵用法类似,具体可以参考
http://arma.sourceforge.net/docs.html
//kd树节点结构体
struct kdTreeNode{
mat data; //节点数据向量 1*n n为特征数
kdTreeNode * leftP; //左子树指针
kdTreeNode * rightP; //右子树指针
kdTreeNode * fatherP; //父亲节点指针,竟然没用到
int splitIndex; //当前节点分割数据的特征索引,目前是轮换特征
kdTreeNode(){leftP = NULL;rightP = NULL;fatherP = NULL;splitIndex = 0;}//构造函数
};
//使用递归建立kd树
/*
nodeP 结构体指针
data 输入数据矩阵,m*n m为样本数 n为特征数
depth 节点深度
*/
void buildKdTree(kdTreeNode *& nodeP,mat data,int depth)
{
//如果输入的数据矩阵为空,则返回
if(data.n_rows == 0)
{
return;
}
//如果只有一个样本,则直接将数据给节点
if(data.n_rows == 1)
{
nodeP->data = data.row(0);
return;
}
//此时样本数大于等于2
int splitFeatureIndex = depth % data.n_cols; //计算分割数据的特征索引
nodeP->splitIndex = splitFeatureIndex;
vec splitFeature = data.col(splitFeatureIndex); //取得特征列向量
vec splitOrder = sort(splitFeature); //排序
double medianValue = splitOrder[data.n_rows/2]; //取中值
mat subsetLeft,subsetRight; //左右矩阵
//接下来是将样本矩阵分为左右两个矩阵
for(int i=0;i<data.n_rows;i++)
{
if(nodeP->data.empty()&&splitFeature[i] == medianValue)
{
//将中间样本给节点
nodeP->data = data.row(i);
}else{
if(splitFeature[i] < medianValue)
{
subsetLeft = join_vert( subsetLeft,data.row(i));
}else{
subsetRight = join_vert( subsetRight,data.row(i));
}
}
}
nodeP->leftP = new kdTreeNode;
nodeP->leftP->fatherP = nodeP;
nodeP->rightP = new kdTreeNode;
nodeP->rightP->fatherP = nodeP;
//递归进入左右子树
buildKdTree(nodeP->leftP,subsetLeft,depth+1);
buildKdTree(nodeP->rightP,subsetRight,depth+1);
}
建立完之后,我们可以写个遍历程序验证下
//中序遍历
void inorder_traverse(kdTreeNode * nodeP)
{
if(nodeP == NULL || nodeP->data.empty())
{
return;
}
inorder_traverse(nodeP->leftP);
cout<<nodeP->data<<endl;
inorder_traverse(nodeP->rightP);
}
//前序遍历
void preorder_traverse(kdTreeNode * nodeP)
{
if(nodeP == NULL || nodeP->data.empty())
{
return;
}
cout<<nodeP->splitIndex<<" "<<nodeP->data<<endl;
preorder_traverse(nodeP->leftP);
preorder_traverse(nodeP->rightP);
}
-
最邻近搜索
现有一个测试目标,为1*n的矩阵,n为特征数,现在就在样本矩阵找到与这个目标距离最近的样本。
搜索就不用迭代了,这样也好理解
具体步骤:
- 第一步:利用二叉搜索到子节点并将,节点指针顺序压入堆栈stack中
使用测试样本的第一个特征值,与根节点比较,若特征值小于或等于根节点的第一个特征值,则下一步在左子树搜索,反之则在右子树。如此循环直至找到叶子节点。并且在这个过程中,不断将得到的节点压入堆栈中,这就形成我们的初步搜索路径,按照堆栈先进后出的原则,堆栈最顶端的元素将是叶子节点。
- 第二步:通过保存在堆栈中的搜索路径回溯
首先计算出最后的叶子结点与目标的距离保存为最短距离minDIstance,并将它设置成最邻近样本
(1)从堆栈中取出一个节点
(2)若是叶子结点,则只有一个步骤
1.计算与目标的距离,若小于保存的minDistance,则将mindistance替换为当前计算的距离,并将当前样本设置为最邻近样本。若大于则跳过。
若不是叶子节点,则有三个步骤
1.此步骤与叶子结点的那个步骤相同
2.计算目标与当前节点划分的超平面的距离(这个可以这么理解:假如在二维直角坐标系中计算某个点与x轴或与x轴平行的一条线的距离,怎么算,这个很显然吧~,同样的,与超平面的距离你应该知道怎么算了)
3.假如目标到超平面的距离大于记录的minDistance。则不执行任何操作,等着下一次回溯。反之若小于minDistance,则说明在当前节点的另一个子树中可能有更近的点。这样,我们找到另一个子树的节点,将它压入到堆栈中(即加入到搜索路径中)并且按照 第一步 的方法,一边往下二叉搜索,一边将遇到的节点压入堆栈中。然后就等着下一次回溯(即从堆栈中取节点)
(3)不断循环步骤(2)直到堆栈中的数据都取出,即搜索路径都走了一遍。
代码到贴k邻近再贴吧
-
k 邻近搜索
网上一直没找到关于k邻近的算法步骤,我自己在最邻近的基础上修改了下,发现可行,不知与官方的算法是否有差距。
其中主体算法步骤与最邻近相似。
就说明一下不同之处:
第一点不同:相较于最邻近只需找一个样本,k邻近算法中,我们定义有个结构体数组来记录k个样本
第二点不同:我们不管他是不是真的最近的k个样本先找到k个再说。所以在收集到k个样本之前,无视最邻近算法中是否小于minDistance这个条件,统统收集到数组中,直到找满k个样本。
第三点不同:找到k个样本之后,立刻找出k个样本中的最大距离的样本,此距离记为maxD,因为他最不靠谱。。恢复最邻近算法中的条件,找到一个比maxD距离更小的样本,就可以将它替换。此时重新计算最maxD。直到堆栈为空。
-
最后
把所有代码都贴上吧。所有代码都在一个文件中。想要运行需要加矩阵库armadillo。x64
如果需要工程文件的,我也上传了。里面包含了矩阵库的lib,dll,头文件,可以直接运行,不需要配置。用的是vs2012 x64
https://download.csdn.net/download/qq_32478489/10707377
#include <iostream>
#include <time.h>
#include <armadillo>
#include <stack>
using namespace std;
using namespace arma;
//kd树节点结构体
struct kdTreeNode{
mat data; //节点数据向量 1*n n为特征数
kdTreeNode * leftP; //左子树指针
kdTreeNode * rightP; //右子树指针
kdTreeNode * fatherP; //父亲节点指针,竟然没用到
int splitIndex; //当前节点分割数据的特征索引,目前是轮换特征
kdTreeNode(){leftP = NULL;rightP = NULL;fatherP = NULL;splitIndex = 0;}//构造函数
};
//使用递归建立kd树
/*
nodeP 结构体指针
data 输入数据矩阵,m*n m为样本数 n为特征数
depth 节点深度
*/
void buildKdTree(kdTreeNode *& nodeP,mat data,int depth)
{
//如果输入的数据矩阵为空,则返回
if(data.n_rows == 0)
{
return;
}
//如果只有一个样本,则直接将数据给节点
if(data.n_rows == 1)
{
nodeP->data = data.row(0);
return;
}
//此时样本数大于等于2
int splitFeatureIndex = depth % data.n_cols; //计算分割数据的特征索引
nodeP->splitIndex = splitFeatureIndex;
vec splitFeature = data.col(splitFeatureIndex); //取得特征列向量
vec splitOrder = sort(splitFeature); //排序
double medianValue = splitOrder[data.n_rows/2]; //取中值
mat subsetLeft,subsetRight; //左右矩阵
//接下来是将样本矩阵分为左右两个矩阵
for(int i=0;i<data.n_rows;i++)
{
if(nodeP->data.empty()&&splitFeature[i] == medianValue)
{
//将中间样本给节点
nodeP->data = data.row(i);
}else{
if(splitFeature[i] < medianValue)
{
subsetLeft = join_vert( subsetLeft,data.row(i));
}else{
subsetRight = join_vert( subsetRight,data.row(i));
}
}
}
nodeP->leftP = new kdTreeNode;
nodeP->leftP->fatherP = nodeP;
nodeP->rightP = new kdTreeNode;
nodeP->rightP->fatherP = nodeP;
//递归进入左右子树
buildKdTree(nodeP->leftP,subsetLeft,depth+1);
buildKdTree(nodeP->rightP,subsetRight,depth+1);
}
//中序遍历
void inorder_traverse(kdTreeNode * nodeP)
{
if(nodeP == NULL || nodeP->data.empty())
{
return;
}
inorder_traverse(nodeP->leftP);
cout<<nodeP->data<<endl;
inorder_traverse(nodeP->rightP);
}
//前序遍历
void preorder_traverse(kdTreeNode * nodeP)
{
if(nodeP == NULL || nodeP->data.empty())
{
return;
}
cout<<nodeP->splitIndex<<" "<<nodeP->data<<endl;
preorder_traverse(nodeP->leftP);
preorder_traverse(nodeP->rightP);
}
//计算欧氏距离
double EuclidianDis(mat & a,mat & b)
{
vec v1 = a.row(0).t();
vec v2 = b.row(0).t();
vec c = v1-v2;
return norm(c,2);
}
struct KNN{
kdTreeNode * nodeP;
double distance;
KNN(){nodeP = NULL;distance = -1;}
};
int isSearched(mat data,mat check)
{
vec d;
vec c = check.row(0).t();
for(int i=0;i<data.n_rows;i++)
{
d = data.row(i).t();
if(all(d==c))
{
return 1;
}
}
return 0;
}
void findMaxInNearGroup(KNN *& NearGroup,int K,int & maxInNG)
{
double maxDis=NearGroup[0].distance;
int maxN=0;
for(int i=1;i<K;i++)
{
if(maxDis<NearGroup[i].distance)
{
maxDis = NearGroup[i].distance;
maxN = i;
}
}
maxInNG = maxN;
}
//利用kd树搜索k邻近节点
void findNearestK(kdTreeNode * RootNodeP,mat & target,KNN *& NearGroup,int & K)
{
//如果指针为空,则返回
if(RootNodeP == NULL) return;
if(K==0)
{
return;
}
/*if(K > RootNodeP.n_rows)
{
K = target.n_rows;
} */
int KgetN = 0;
//第一步:利用二叉搜索到子节点并将,节点指针顺序压入堆栈stack中
kdTreeNode * nodeP = RootNodeP;
stack <kdTreeNode *> search_path;
int maxInNGIndex = 0;
while(nodeP !=NULL && !nodeP->data.empty())//判断遇到空节点,或者节点内无数据,则结束
{
search_path.push(nodeP);
if(target(0,nodeP->splitIndex) <= nodeP->data(0,nodeP->splitIndex))
{
nodeP = nodeP->leftP;
}else{
nodeP = nodeP->rightP;
}
}
kdTreeNode * firstNodeP = search_path.top();
NearGroup[0].nodeP= search_path.top(); //将得到的叶子结点记入下来
NearGroup[0].distance = EuclidianDis(target,NearGroup[0].nodeP->data); //并计算距离
KgetN = 1; //
//第二步:然后通过堆栈回溯
kdTreeNode * back_nodeP;
while(search_path.empty() == 0) //直到堆栈为空,即已经回溯完为止
{
back_nodeP = search_path.top();
search_path.pop();
//若当前回溯节点为叶子节点,只计算与target距离即可。
//如果距离小于当前记录的最小距离,则更新最小值,否则跳过即可
if((back_nodeP->leftP == NULL|| back_nodeP->leftP->data.empty())
&&(back_nodeP->rightP == NULL||back_nodeP->rightP->data.empty()))
{
if(back_nodeP != firstNodeP)
{
if((KgetN<K)||(EuclidianDis(back_nodeP->data,target)<NearGroup[maxInNGIndex].distance))
{
if(KgetN<K)
{
NearGroup[KgetN].nodeP = back_nodeP;
NearGroup[KgetN].distance = EuclidianDis(back_nodeP->data,target);
KgetN++;
if(KgetN==K)
{
findMaxInNearGroup(NearGroup,KgetN,maxInNGIndex);
}
}else{
NearGroup[maxInNGIndex].nodeP = back_nodeP;
NearGroup[maxInNGIndex].distance = EuclidianDis(back_nodeP->data,target);
findMaxInNearGroup(NearGroup,KgetN,maxInNGIndex);
}
}
}
}else{
//若当前回溯点带有子树,不仅需要计算距离,并更新
//而且需要计算目标到超平面的距离,即特征值相减即可
if((KgetN<K)||(EuclidianDis(back_nodeP->data,target)<NearGroup[maxInNGIndex].distance))
{
if(KgetN<K)
{
NearGroup[KgetN].nodeP = back_nodeP;
NearGroup[KgetN].distance = EuclidianDis(back_nodeP->data,target);
KgetN++;
if(KgetN==K)
{
findMaxInNearGroup(NearGroup,KgetN,maxInNGIndex);
}
}else{
NearGroup[maxInNGIndex].nodeP = back_nodeP;
NearGroup[maxInNGIndex].distance = EuclidianDis(back_nodeP->data,target);
findMaxInNearGroup(NearGroup,KgetN,maxInNGIndex);
}
}
//计算到超品面距离,如果距离小于当前记录的最小距离
//则需要进入另一个子树,并像第一步一般,不断二叉搜索,并将搜索路径压入堆栈
if((KgetN<K)||(abs(target(0,back_nodeP->splitIndex)-back_nodeP->data(0,back_nodeP->splitIndex))<NearGroup[maxInNGIndex].distance))
{
kdTreeNode * childNodeP;
if(target(0,back_nodeP->splitIndex) > back_nodeP->data(0,back_nodeP->splitIndex))
{
childNodeP = back_nodeP->leftP;
}else
{
childNodeP = back_nodeP->rightP;
}
while(childNodeP !=NULL && !childNodeP->data.empty())
{
search_path.push(childNodeP);
if(target(0,childNodeP->splitIndex) <= childNodeP->data(0,childNodeP->splitIndex))
{
childNodeP = childNodeP->leftP;
}else{
childNodeP = childNodeP->rightP;
}
}
}
}
}
}
//线性扫描的方法得到最邻近值,用以验证算法
void scanDis(mat & data,mat & target,mat & nearest,double & distance)
{
mat a = data.row(0);
mat b = target;
mat near = data.row(0);
double minDis=EuclidianDis(a,b);
for(int i=1;i<data.n_rows;i++)
{
a = data.row(i);
b = target;
double dis = EuclidianDis(a,b);
if(dis<minDis)
{
minDis = dis;
near = data.row(i);
}
}
distance = minDis;
nearest = near;
}
void main(void)
{
mat dataTrain,target,nearest,nearest1,searched;
int nK = 5;
KNN * nearstK = new KNN[nK];
double minDistance,minDistance1;
dataTrain<<2<<3<<endr
<<5<<4<<endr
<<9<<6<<endr
<<4<<7<<endr
<<8<<1<<endr
<<7<<2<<endr;
target<<2<<4.5<<endr;
/*dataTrain = randu<mat>(50000,5)*100;
target = randu<mat>(1,5)*100;*/
kdTreeNode * rootP = new kdTreeNode;
buildKdTree(rootP,dataTrain,0);
// inorder_traverse(rootP);
cout<<"target "<<endl<<target<<endl;
clock_t arithmeticStart = clock();
findNearestK(rootP,target,nearstK,nK);
clock_t arithmeticEnd = clock();
cout<<"KD Tree result"<<endl;
for(int i=0;i<nK;i++)
{
cout<<nearstK[i].nodeP->data<<endl;
cout<<nearstK[i].distance<<endl;
}
double runTime = (double)(arithmeticEnd - arithmeticStart)/1000;
cout<<"KD Tree Running Time : "<<runTime<<" s"<<endl;
arithmeticStart = clock();
scanDis(dataTrain,target,nearest1,minDistance1);
arithmeticEnd = clock();
cout<<"Scan result"<<endl<<nearest1<<minDistance1<<endl;
runTime = (double)(arithmeticEnd - arithmeticStart)/1000;
cout<<"Scan Running Time : "<<runTime<<" s"<<endl;
cout<<"press any key to terminate.";
getchar();
}