构造平衡kd树与使用kd树进行最近邻搜索
//数据点
/**
 * One point of the kd-tree together with its tree links and a scratch distance field.
 */
public class Node {
    /** The point's integer coordinates; the array is kept by reference, not copied. */
    int[] x;

    /** Tree links; all start out null and are assigned while the tree is built. */
    Node left, right, parent;

    /** Scratch value written by the search code (SearchPoint stores a squared distance here). */
    double distance;

    /**
     * Creates an unlinked node for the given point.
     *
     * @param x coordinates of the point (stored by reference)
     */
    public Node(int[] x) {
        this.x = x;
    }
}
//排序
import java.util.ArrayList;
import java.util.List;
public class Sort {
    /**
     * Returns the given nodes ordered ascending by their coordinate on axis {@code index}.
     *
     * <p>The sort is stable: nodes with equal coordinates keep their relative input order. It runs
     * in O(n log n), and the input list is left unmodified, so it may also be an unmodifiable
     * list.
     *
     * @param list  nodes to sort; not modified
     * @param index coordinate (axis) to sort on
     * @return a new, mutable list holding the same nodes in sorted order
     */
    public List<Node> sort(List<Node> list, int index) {
        // Work on a copy so the caller's list is not consumed as a side effect.
        List<Node> sorted = new ArrayList<>(list);
        // List.sort is guaranteed stable, so ties keep their input order.
        sorted.sort((a, b) -> Integer.compare(a.x[index], b.x[index]));
        return sorted;
    }
}
//对排序之后的list取中位数
import java.util.List;
public class Median {
    /**
     * Picks the median of a list that is already sorted on the current split axis.
     * For an even-sized list this is the upper of the two middle elements.
     *
     * @param list sorted, non-empty list of nodes
     * @return the node at position {@code list.size() / 2}
     * @throws IndexOutOfBoundsException if {@code list} is empty
     */
    public Node median(List<Node> list) {
        int middle = list.size() / 2;
        return list.get(middle);
    }
}
//构造kd树
import java.util.ArrayList;
import java.util.List;
public class Huafen {
    /**
     * Recursively builds a balanced kd-tree ("huafen" means "partition").
     *
     * <p>The nodes are sorted on axis {@code index}, and the median chosen by {@link Median} becomes
     * the subtree root. Nodes before it in sorted order form the left subtree and nodes after it
     * form the right subtree. Both subtrees are built on the next axis, wrapping back to axis 0
     * after the last coordinate. The {@code left}, {@code right} and {@code parent} links of every
     * placed node are overwritten.
     *
     * @param list  nodes to place in the tree; may be {@code null} or empty
     * @param index split axis for this level, expected in {@code [0, dimension)}
     * @return root of the built (sub)tree, or {@code null} when {@code list} is null or empty
     */
    public Node huafen(List<Node> list, int index) {
        // Test for null before calling any method on the list.
        if (list == null || list.isEmpty()) {
            return null;
        }
        Sort s = new Sort();
        Median m = new Median();
        List<Node> sorted = s.sort(list, index);
        Node root = m.median(sorted);
        // Position of the chosen median inside the sorted list (Node uses identity equality).
        int mid = sorted.indexOf(root);

        // Split by position rather than by comparing coordinates. A strict </> test would drop
        // every other node whose coordinate on this axis equals the median's.
        List<Node> left = new ArrayList<>(sorted.subList(0, mid));
        List<Node> right = new ArrayList<>(sorted.subList(mid + 1, sorted.size()));

        int nextIndex = (index + 1) % root.x.length;
        // Clear any link left over from an earlier build; a caller re-links subtree roots below.
        root.parent = null;
        root.left = huafen(left, nextIndex);
        root.right = huafen(right, nextIndex);
        if (root.left != null) {
            root.left.parent = root;
        }
        if (root.right != null) {
            root.right.parent = root;
        }
        return root;
    }
}
//输出
public class Output {
    /**
     * Prints the tree in pre-order (node, left subtree, right subtree).
     *
     * <p>Each point goes on its own line with its coordinates separated by commas, e.g.
     * {@code 7,2}.
     *
     * @param root root of the (sub)tree to print; nothing is printed for {@code null}
     */
    public void output(Node root) {
        if (root == null) {
            return;
        }
        // Join every coordinate rather than hard-coding x[0] and x[1], so points of any dimension
        // print correctly. Output is identical for 2-D points.
        StringBuilder line = new StringBuilder();
        for (int i = 0; i < root.x.length; i++) {
            if (i > 0) {
                line.append(',');
            }
            line.append(root.x[i]);
        }
        System.out.println(line);
        output(root.left);
        output(root.right);
    }
}
public class SearchPoint {
//从上到下搜索最近点
public Node UpToDownSearch(Node root, double[] x, int index){
while(root != null ) {
if (x[index] <= root.x[index]) {
if(root.left!=null) {
root = root.left;
}
else
break;
} else if(x[index] > root.x[index]){
if(root.right!=null) {
root = root.right;
}
else
break;
}
}
double distance = 0;
for(int i = 0; i < x.length; i++){
distance = distance + Math.pow(root.x[i]-x[i],2);
}
root.distance = distance;
return root;
}
//从下到上搜索最近点
public Node DownToUpSearch(Node root,Node roott, double[] x, int index){
Node root1 = root.parent;
Node node = null;
double distance = 0;
for(int i = 0; i < x.length; i ++){
distance = distance + Math.pow(root1.x[i]-x[i],2);//计算当前最近点父节点的距离
}
root1.distance = distance;
//判断是否进入另一个区域
Node point = null;
if(root==root1.left){
point = root1.right;
}
else{
point = root1.left;
}
//与平面相交,进入另一个区域
if(Math.abs(x[index+1]-root1.x[index+1])<roott.distance) {
double distance2 = 0;
for (int k = 0; k < x.length; k++) {
distance2 = distance2 + Math.pow(x[k] - point.x[k], 2);
}
point.distance = distance2;
if(point.distance<roott.distance) {//如果另一个兄弟节点的距离小于当前最近点的距离,则更新
if(point.distance<root1.distance) {
roott = point;
}
else
roott = root1;
}else {
if(root1.distance<roott.distance){
roott =root1;
}
}
}
if(root.parent.parent!=null){
DownToUpSearch(root.parent,roott,x,index);
}
return roott;
}
//例子
import java.util.ArrayList;
import java.util.List;
public class solution {
    /**
     * Demo driver. It builds a kd-tree from six 2-D points and prints the tree in pre-order.
     * It then runs the two-phase search (top-down, then bottom-up) for the query (8, 2) and
     * prints the point it returns.
     */
    public static void main(String[] args) {
        int[][] coordinates = {{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}};
        List<Node> points = new ArrayList<>();
        for (int[] c : coordinates) {
            points.add(new Node(c));
        }

        Node treeRoot = new Huafen().huafen(new Sort().sort(points, 0), 0);
        new Output().output(treeRoot);

        SearchPoint searcher = new SearchPoint();
        double[] query = {8, 2};
        Node leaf = searcher.UpToDownSearch(treeRoot, query, 0);
        Node nearest = searcher.DownToUpSearch(leaf, leaf, query, 0);

        System.out.println("最近点为:");
        System.out.println(nearest.x[0] + "," + nearest.x[1]);
    }
}