public classkd_main {public static voidmain(String[] args) {
List nodeList=new ArrayList();
nodeList.add(new Node(new double[]{5,4}));
nodeList.add(new Node(new double[]{9,6}));
nodeList.add(new Node(new double[]{8,1}));
nodeList.add(new Node(new double[]{7,2}));
nodeList.add(new Node(new double[]{2,3}));
nodeList.add(new Node(new double[]{4,7}));
nodeList.add(new Node(new double[]{4,3}));
nodeList.add(new Node(new double[]{1,3}));
kd_main kdTree=newkd_main();
Node root=kdTree.buildKDTree(nodeList,0);newBinaryTreeOrder().preOrder(root);for(Node node : nodeList) {
System.out.println(node.toString()+"-->"+node.left.toString()+"-->"+node.right.toString());
}
System.out.println(root);
System.out.println(kdTree.searchKNN(root,new Node(new double[]{2.1,3.1}),2));
System.out.println(kdTree.searchKNN(root,new Node(new double[]{2,4.5}),1));
System.out.println(kdTree.searchKNN(root,new Node(new double[]{2,4.5}),3));
System.out.println(kdTree.searchKNN(root,new Node(new double[]{6,1}),5));
}/*** 构建kd树 返回根节点
*@paramnodeList
*@paramindex
*@return
*/
public Node buildKDTree(List nodeList,intindex)
{if(nodeList==null || nodeList.size()==0)return null;
quickSortForMedian(nodeList,index,0,nodeList.size()-1);//中位数排序
Node root=nodeList.get(nodeList.size()/2);//中位数 当做根节点
root.dim=index;
List leftNodeList=new ArrayList();//放入左侧区域的节点 包括包含与中位数等值的节点-_-
List rightNodeList=new ArrayList();for(Node node:nodeList)
{if(root!=node)
{if(node.getData(index)<=root.getData(index))
leftNodeList.add(node);//左子区域 包含与中位数等值的节点
elserightNodeList.add(node);
}
}//计算从哪一维度切分
int newIndex=index+1;//进入下一个维度
if(newIndex>=root.data.length)
newIndex=0;//从0维度开始再算
root.left=buildKDTree(leftNodeList,newIndex);//添加左右子区域
root.right=buildKDTree(rightNodeList,newIndex);if(root.left!=null)
root.left.parent=root;//添加父指针
if(root.right!=null)
root.right.parent=root;//添加父指针
returnroot;
}/*** 查询最近邻
*@paramroot kd树
*@paramq 查询点
*@paramk
*@return
*/
public List searchKNN(Node root,Node q,intk)
{
List knnList=new ArrayList();
searchBrother(knnList,root,q,k);returnknnList;
}/*** searhchBrother
*@paramknnList
*@paramk
*@paramq*/
public void searchBrother(List knnList, Node root, Node q, intk) {//Node almostNNode=root;//近似最近点
Node leafNNode=searchLeaf(root,q);double curD=q.computeDistance(leafNNode);//最近近似点与查询点的距离 也就是球体的半径
leafNNode.distance=curD;
maintainMaxHeap(knnList,leafNNode,k);
System.out.println("leaf1"+leafNNode.getData(leafNNode.parent.dim));while(leafNNode!=root)
{if (getBrother(leafNNode)!=null) {
Node brother=getBrother(leafNNode);
System.out.println("brother1"+brother.getData(brother.parent.dim));if(curD>Math.abs(q.getData(leafNNode.parent.dim)-leafNNode.parent.getData(leafNNode.parent.dim))||knnList.size()
{//这样可能在另一个子区域中存在更加近似的点
searchBrother(knnList,brother, q, k);
}
}
System.out.println("leaf2"+leafNNode.getData(leafNNode.parent.dim));
leafNNode=leafNNode.parent;//返回上一级
double rootD=q.computeDistance(leafNNode);//最近近似点与查询点的距离 也就是球体的半径
leafNNode.distance=rootD;
maintainMaxHeap(knnList,leafNNode,k);
}
}/*** 获取兄弟节点
*@paramnode
*@return
*/
publicNode getBrother(Node node)
{if(node==node.parent.left)returnnode.parent.right;else
returnnode.parent.left;
}/*** 查询到叶子节点
*@paramroot
*@paramq
*@return
*/
publicNode searchLeaf(Node root,Node q)
{
Node leaf=root,next=null;int index=0;while(leaf.left!=null || leaf.right!=null)
{if(q.getData(index)
{
next=leaf.left;//进入左侧
}else if(q.getData(index)>leaf.getData(index))
{
next=leaf.right;
}else{//当取到中位数时 判断左右子区域哪个更加近
if(q.computeDistance(leaf.left)
next=leaf.left;elsenext=leaf.right;
}if(next==null)break;//下一个节点是空时 结束了
else{
leaf=next;if(++index>=root.data.length)
index=0;
}
}returnleaf;
}/*** 维护一个k的最大堆
*@paramlistNode
*@paramnewNode
*@paramk*/
public void maintainMaxHeap(List listNode,Node newNode,intk)
{if(listNode.size()
{
maxHeapFixUp(listNode,newNode);//不足k个堆 直接向上修复
}else if(newNode.distance
maxHeapFixDown(listNode,newNode);
}
}/*** 从上往下修复 将会覆盖第一个节点
*@paramlistNode
*@paramnewNode*/
private void maxHeapFixDown(ListlistNode,Node newNode)
{
listNode.set(0, newNode);int i=0;int j=i*2+1;while(j
{if(j+1
j++;//选出子结点中较大的点,第一个条件是要满足右子树不为空
if(listNode.get(i).distance>=listNode.get(j).distance)break;
Node t=listNode.get(i);
listNode.set(i, listNode.get(j));
listNode.set(j, t);
i=j;
j=i*2+1;
}
}private void maxHeapFixUp(ListlistNode,Node newNode)
{
listNode.add(newNode);int j=listNode.size()-1;int i=(j+1)/2-1;//i是j的parent节点
while(i>=0)
{if(listNode.get(i).distance>=listNode.get(j).distance)break;
Node t=listNode.get(i);
listNode.set(i, listNode.get(j));
listNode.set(j, t);
j=i;
i=(j+1)/2-1;
}
}/*** 使用快排进进行一个中位数的查找 完了之后返回的数组size/2即中位数
*@paramnodeList
*@paramindex
*@paramleft
*@paramright*/@Testprivate void quickSortForMedian(List nodeList,int index,int left,intright)
{if(left>=right || nodeList.size()<=0)return;
Node kn=nodeList.get(left);double k=kn.getData(index);//取得向量指定索引的值
int i=left,j=right;//控制每一次遍历的结束条件,i与j相遇
while(i
{//从右向左找一个小于i处值的值,并填入i的位置
while(nodeList.get(j).getData(index)>=k && i
j--;
nodeList.set(i, nodeList.get(j));//从左向右找一个大于i处值的值,并填入j的位置
while(nodeList.get(i).getData(index)<=k && i
i++;
nodeList.set(j, nodeList.get(i));
}
nodeList.set(i, kn);if(i==nodeList.size()/2)return ;//完成中位数的排序了,但并不是完成了所有数的排序,这个终止条件只是保证中位数是正确的。去掉该条件,可以保证在递归的作用下,将所有的树//将所有的数进行排序
else if(i
{
quickSortForMedian(nodeList,index,i+1,right);//只需要排序右边就可以了
}else{
quickSortForMedian(nodeList,index,left,i-1);//只需要排序左边就可以了
}//for (Node node : nodeList) {//System.out.println(node.getData(index));//}
}
}