版权声明:本文为原创文章,未经博主允许不得用于商业用途。
kd-tree 算法的原理可参考知乎上的这篇文章,这里使用 Java 实现了二维 kd 树。主要代码如下:
/**
 * Two-dimensional k-d tree over {@code dPoint}s with k-nearest-neighbour search.
 *
 * <p>Levels alternate between splitting on x and on y, starting with x at the top.
 * {@code ROOT} is an empty sentinel node; the actual tree hangs off
 * {@code ROOT.leftChild}. Query state ({@code nn}, {@code nnsdist}, {@code k}) is
 * kept in instance fields, so overlapping {@code KNN} calls on the same instance
 * would overwrite each other's state.
 */
class KDTree{
    /** Sentinel node; the topmost real node is stored in {@code ROOT.leftChild}. */
    protected KDNode ROOT;
    /** Candidate neighbours collected by the query currently in progress. */
    protected ArrayList<KDNode> nn;
    /** Squared distance from the query point to {@code nn.get(i)}, index-aligned with {@code nn}. */
    protected double[] nnsdist;
    /** Number of neighbours requested by the query currently in progress. */
    protected int k;

    /**
     * Builds a tree from the given points.
     *
     * @param pointlist points to index; the list is copied first, so the caller's
     *                  list is not reordered by the build
     */
    public KDTree(ArrayList<dPoint> pointlist){
        ROOT = new KDNode();
        // BuildTree sorts sub-ranges in place; work on a private copy.
        ArrayList<dPoint> points = new ArrayList<>(pointlist);
        BuildTree(points, ROOT, 0, points.size(), true, true);
    }

    /** Prints every stored point as {@code x,y} in pre-order; prints nothing for an empty tree. */
    public void DFS()
    {
        DFS(ROOT.leftChild);
    }

    /**
     * Finds the {@code k} stored points closest (Euclidean distance) to {@code center}.
     *
     * @param k      number of neighbours wanted; must not be negative
     * @param center query point
     * @return copies of up to {@code k} nearest points (fewer if the tree holds fewer),
     *         in no particular order; empty when {@code k} is 0
     * @throws IllegalArgumentException if {@code k} is negative
     */
    public ArrayList<dPoint> KNN(int k, dPoint center){
        if (k < 0) {
            throw new IllegalArgumentException("k must not be negative: " + k);
        }
        this.k = k;
        this.nn = new ArrayList<>();
        this.nnsdist = new double[k];
        if (k > 0) {
            KNN(center, ROOT.leftChild);
        }
        ArrayList<dPoint> knn = new ArrayList<>();
        for (KDNode x : nn) {
            knn.add(x.dividePoint.Copy());
        }
        return knn;
    }

    /**
     * Recursive search step: visits the half containing {@code center} first, then
     * offers this node as a candidate, then visits the other half unless it provably
     * cannot contain anything closer than the current candidates.
     *
     * @param center query point
     * @param root   subtree to search; {@code null} is a no-op
     */
    protected void KNN(dPoint center, KDNode root)
    {
        if (root == null || k <= 0) {
            return;
        }
        double queryCoord = root.xaxis ? center.x : center.y;
        double splitCoord = root.xaxis ? root.dividePoint.x : root.dividePoint.y;
        boolean leftpart = queryCoord < splitCoord;

        // Near side first so good candidates are found before pruning decisions.
        KNN(center, leftpart ? root.leftChild : root.rightChild);

        double dx = center.x - root.dividePoint.x;
        double dy = center.y - root.dividePoint.y;
        double cdist = dx * dx + dy * dy;
        if (nn.size() < k) {
            nnsdist[nn.size()] = cdist;
            nn.add(root);
        } else {
            int maxidx = farthestCandidate();
            // If nearer, replace the farthest candidate with the current point.
            if (nnsdist[maxidx] > cdist) {
                nn.set(maxidx, root);
                nnsdist[maxidx] = cdist;
            }
        }

        if (nn.size() >= k) {
            // nnsdist holds squared distances, so the gap to the splitting line must be
            // squared too. The worst candidate is looked up again because a replacement
            // above may have changed which candidate is the farthest.
            double planeGap = queryCoord - splitCoord;
            if (nnsdist[farthestCandidate()] < planeGap * planeGap) {
                return;
            }
        }
        KNN(center, leftpart ? root.rightChild : root.leftChild);
    }

    /** Returns the index of the largest squared distance among the current candidates. */
    private int farthestCandidate()
    {
        int maxidx = 0;
        for (int i = 1; i < nn.size(); i++) {
            if (nnsdist[i] > nnsdist[maxidx]) {
                maxidx = i;
            }
        }
        return maxidx;
    }

    /**
     * Prints the subtree rooted at {@code root} in pre-order, one {@code x,y} line per node.
     *
     * @param root subtree to print; {@code null} is a no-op
     */
    protected void DFS(KDNode root)
    {
        if (root == null) {
            return;
        }
        System.out.println(root.dividePoint.x+","+root.dividePoint.y);
        DFS(root.leftChild);
        DFS(root.rightChild);
    }

    /**
     * Builds the subtree for {@code pointlist[start, end)} and attaches it to {@code root}.
     * The range is sorted in place along the current axis and its middle element
     * becomes the dividing point.
     *
     * @param pointlist points being indexed; sub-ranges are reordered in place
     * @param root      parent node that receives the new subtree
     * @param start     first index of the range (inclusive)
     * @param end       end of the range (exclusive)
     * @param xaxis     {@code true} to split this level on x, {@code false} on y
     * @param left      {@code true} to attach as {@code root.leftChild}, else as {@code root.rightChild}
     */
    protected void BuildTree(ArrayList<dPoint> pointlist,KDNode root , int start, int end, boolean xaxis, boolean left){
        if (start >= end) {
            return;
        }
        int mid = start;
        if (start + 1 < end) {
            if (xaxis) {
                pointlist.subList(start, end).sort((dPoint p1, dPoint p2) -> Double.compare(p1.x, p2.x));
            } else {
                pointlist.subList(start, end).sort((dPoint p1, dPoint p2) -> Double.compare(p1.y, p2.y));
            }
            // Unsigned shift keeps the midpoint correct even if start + end overflows int.
            mid = (start + end) >>> 1;
        }
        dPoint median = pointlist.get(mid);
        KDNode node = new KDNode();
        node.dividePoint = new dPoint(median.x, median.y);
        node.xaxis = xaxis;
        if (left) {
            root.leftChild = node;
        } else {
            root.rightChild = node;
        }
        BuildTree(pointlist, node, start, mid, !xaxis, true);
        BuildTree(pointlist, node, mid + 1, end, !xaxis, false);
    }
}
可视化效果:
具体代码见 GitHub。