kd树 java_统计学习方法学习(四)--KNN及kd树的java实现

public classkd_main {public static voidmain(String[] args) {

List nodeList=new ArrayList();

nodeList.add(new Node(new double[]{5,4}));

nodeList.add(new Node(new double[]{9,6}));

nodeList.add(new Node(new double[]{8,1}));

nodeList.add(new Node(new double[]{7,2}));

nodeList.add(new Node(new double[]{2,3}));

nodeList.add(new Node(new double[]{4,7}));

nodeList.add(new Node(new double[]{4,3}));

nodeList.add(new Node(new double[]{1,3}));

kd_main kdTree=newkd_main();

Node root=kdTree.buildKDTree(nodeList,0);newBinaryTreeOrder().preOrder(root);for(Node node : nodeList) {

System.out.println(node.toString()+"-->"+node.left.toString()+"-->"+node.right.toString());

}

System.out.println(root);

System.out.println(kdTree.searchKNN(root,new Node(new double[]{2.1,3.1}),2));

System.out.println(kdTree.searchKNN(root,new Node(new double[]{2,4.5}),1));

System.out.println(kdTree.searchKNN(root,new Node(new double[]{2,4.5}),3));

System.out.println(kdTree.searchKNN(root,new Node(new double[]{6,1}),5));

}/*** 构建kd树 返回根节点

*@paramnodeList

*@paramindex

*@return

*/

public Node buildKDTree(List nodeList,intindex)

{if(nodeList==null || nodeList.size()==0)return null;

quickSortForMedian(nodeList,index,0,nodeList.size()-1);//中位数排序

Node root=nodeList.get(nodeList.size()/2);//中位数 当做根节点

root.dim=index;

List leftNodeList=new ArrayList();//放入左侧区域的节点 包括包含与中位数等值的节点-_-

List rightNodeList=new ArrayList();for(Node node:nodeList)

{if(root!=node)

{if(node.getData(index)<=root.getData(index))

leftNodeList.add(node);//左子区域 包含与中位数等值的节点

elserightNodeList.add(node);

}

}//计算从哪一维度切分

int newIndex=index+1;//进入下一个维度

if(newIndex>=root.data.length)

newIndex=0;//从0维度开始再算

root.left=buildKDTree(leftNodeList,newIndex);//添加左右子区域

root.right=buildKDTree(rightNodeList,newIndex);if(root.left!=null)

root.left.parent=root;//添加父指针

if(root.right!=null)

root.right.parent=root;//添加父指针

returnroot;

}/*** 查询最近邻

*@paramroot kd树

*@paramq 查询点

*@paramk

*@return

*/

public List searchKNN(Node root,Node q,intk)

{

List knnList=new ArrayList();

searchBrother(knnList,root,q,k);returnknnList;

}/*** searhchBrother

*@paramknnList

*@paramk

*@paramq*/

public void searchBrother(List knnList, Node root, Node q, intk) {//Node almostNNode=root;//近似最近点

Node leafNNode=searchLeaf(root,q);double curD=q.computeDistance(leafNNode);//最近近似点与查询点的距离 也就是球体的半径

leafNNode.distance=curD;

maintainMaxHeap(knnList,leafNNode,k);

System.out.println("leaf1"+leafNNode.getData(leafNNode.parent.dim));while(leafNNode!=root)

{if (getBrother(leafNNode)!=null) {

Node brother=getBrother(leafNNode);

System.out.println("brother1"+brother.getData(brother.parent.dim));if(curD>Math.abs(q.getData(leafNNode.parent.dim)-leafNNode.parent.getData(leafNNode.parent.dim))||knnList.size()

{//这样可能在另一个子区域中存在更加近似的点

searchBrother(knnList,brother, q, k);

}

}

System.out.println("leaf2"+leafNNode.getData(leafNNode.parent.dim));

leafNNode=leafNNode.parent;//返回上一级

double rootD=q.computeDistance(leafNNode);//最近近似点与查询点的距离 也就是球体的半径

leafNNode.distance=rootD;

maintainMaxHeap(knnList,leafNNode,k);

}

}/*** 获取兄弟节点

*@paramnode

*@return

*/

publicNode getBrother(Node node)

{if(node==node.parent.left)returnnode.parent.right;else

returnnode.parent.left;

}/*** 查询到叶子节点

*@paramroot

*@paramq

*@return

*/

publicNode searchLeaf(Node root,Node q)

{

Node leaf=root,next=null;int index=0;while(leaf.left!=null || leaf.right!=null)

{if(q.getData(index)

{

next=leaf.left;//进入左侧

}else if(q.getData(index)>leaf.getData(index))

{

next=leaf.right;

}else{//当取到中位数时 判断左右子区域哪个更加近

if(q.computeDistance(leaf.left)

next=leaf.left;elsenext=leaf.right;

}if(next==null)break;//下一个节点是空时 结束了

else{

leaf=next;if(++index>=root.data.length)

index=0;

}

}returnleaf;

}/*** 维护一个k的最大堆

*@paramlistNode

*@paramnewNode

*@paramk*/

public void maintainMaxHeap(List listNode,Node newNode,intk)

{if(listNode.size()

{

maxHeapFixUp(listNode,newNode);//不足k个堆 直接向上修复

}else if(newNode.distance

maxHeapFixDown(listNode,newNode);

}

}/*** 从上往下修复 将会覆盖第一个节点

*@paramlistNode

*@paramnewNode*/

private void maxHeapFixDown(ListlistNode,Node newNode)

{

listNode.set(0, newNode);int i=0;int j=i*2+1;while(j

{if(j+1

j++;//选出子结点中较大的点,第一个条件是要满足右子树不为空

if(listNode.get(i).distance>=listNode.get(j).distance)break;

Node t=listNode.get(i);

listNode.set(i, listNode.get(j));

listNode.set(j, t);

i=j;

j=i*2+1;

}

}private void maxHeapFixUp(ListlistNode,Node newNode)

{

listNode.add(newNode);int j=listNode.size()-1;int i=(j+1)/2-1;//i是j的parent节点

while(i>=0)

{if(listNode.get(i).distance>=listNode.get(j).distance)break;

Node t=listNode.get(i);

listNode.set(i, listNode.get(j));

listNode.set(j, t);

j=i;

i=(j+1)/2-1;

}

}/*** 使用快排进进行一个中位数的查找 完了之后返回的数组size/2即中位数

*@paramnodeList

*@paramindex

*@paramleft

*@paramright*/@Testprivate void quickSortForMedian(List nodeList,int index,int left,intright)

{if(left>=right || nodeList.size()<=0)return;

Node kn=nodeList.get(left);double k=kn.getData(index);//取得向量指定索引的值

int i=left,j=right;//控制每一次遍历的结束条件,i与j相遇

while(i

{//从右向左找一个小于i处值的值,并填入i的位置

while(nodeList.get(j).getData(index)>=k && i

j--;

nodeList.set(i, nodeList.get(j));//从左向右找一个大于i处值的值,并填入j的位置

while(nodeList.get(i).getData(index)<=k && i

i++;

nodeList.set(j, nodeList.get(i));

}

nodeList.set(i, kn);if(i==nodeList.size()/2)return ;//完成中位数的排序了,但并不是完成了所有数的排序,这个终止条件只是保证中位数是正确的。去掉该条件,可以保证在递归的作用下,将所有的树//将所有的数进行排序

else if(i

{

quickSortForMedian(nodeList,index,i+1,right);//只需要排序右边就可以了

}else{

quickSortForMedian(nodeList,index,left,i-1);//只需要排序左边就可以了

}//for (Node node : nodeList) {//System.out.println(node.getData(index));//}

}

}

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值