书接上文,我们这一张来讨论KD Tree的最近邻搜索。首先来看一下《统计学习方法》书中给出的算法描述:
我们用网上很常见的例子进行分析:
(7,2),(5,4),(9,6),(2,3),(4,7),(8,1)
待测节点:(2,4.5)
我们先不管在坐标上是怎么划分的,先看看在KD Tree上是如何遍历的。
我们先来进行分析:
(1)在kd树中找出包含目标点x的叶节点。
这一步比较简单,首先根据树的每一层分割所用的维度(Flag)进行搜索,直到找到叶子节点位置(Position),搜索路径图表示如下:
只需要比较该维度坐标值和树节点上该维度坐标值大小即可:
1>比较目标节点(node)和树上节点(TNode)对应维度(dia)上坐标值的大小:
1.1>如果node.get(dia) <= TNode.get(dia):递归搜索左孩子。
1.2>否则递归搜索右孩子。
2>当搜索到叶子节点时,返回之并退出。
private KDTreeNode GetLeaf(KDTreeNode Tnode, KDNode node) {
int dia = Tnode.flag%KDNode.dimension;
if (Tnode.getValue() == null) { return Tnode.getFather();
//如果找到叶子节点还未找到,那就把他父节点设为最近邻
} else if (node.get(dia) < Tnode.getValue().get(dia)) {
return GetLeaf(Tnode.getLeft(), node);
}else
return GetLeaf(Tnode.getRight(), node);
}
2)以此叶子节点为“当前最近点“(CurNearest)。
3)递归向上会退,在每个节点进行如下操作:
a)如果该节点保存的实例点比当前最近最近点距离目标节点更近,则更新“当前最近点“。
b)当前最近点一定存在于该节点(SplitPotint)的一个子节点对应的区域。检查该子节点的父节点的另一子节点对应的区域是否有更近的节点。具体的检查另一子节点对应的区域是否与以目标点为球心,目标点到当前最近点距离为半径的超球相交。
b.1>如果相交,则可能在另一子节点对应的区域存在距离目标点更近的点,移动到另一子节点,递归搜索。
b.2>如果不相交,向上退回。
4>当退回到跟节点时,搜索结束,返回当前最近点。
这段描述是该算法的核心与难点,实现过程中遇到众多难题,其中之一便是“向上退回“时进入死循环。咱们先看描述,再来说我遇到的问题。
其实并不难理解,
第一步:从一个叶子节点向上搜索,如果父节点离node近,则更新之。
第二步:比较当前的两点(Curnearest&node)间距离和node到超平面的距离的大小(是否相交),
- 如果不相交,那很好说,直接搜索父节点即可。
- 如果相交,问题来了,就需要进入其兄弟节点看看有没有比CurNearest更近的节点。如果有,则更新之,如果没有则返回父节点。
-注意,此时返回父节点后,可知该父节点的两个子节点都没有比CurNearest更近的节点了,此时向上退回到爷爷节点。我在这一步的实现上出现了问题。问题就是在递归调用时,无法判断该节点是退回到的节点还是向下搜索到的节点,也就是说无法判断该节点的两个子节点是否都已经被搜索过了。即如下图所示,无法区分是绿线搜索到还是蓝线搜索到,后面我来说一下我的解决方法。
我使用了路径栈(pathStack)的方式存储所有搜索过的节点,这样判断该节点是否在栈中就可以区分啦!
- 使用方式:搜素一个节点,首先判断该节点是否在栈中,如果在则说明是回退搜索到的,则将该节点弹出栈,并回退到其父节点处。如果不在栈中,则说明是向下搜索到的该节点,则入栈该节点,并进行a,b操作。
实现代码如下:
KDTreeNode GetNearest(KDNode node, KDTreeNode nearest, KDTreeNode spiltPoint) {
KDTreeNode CurNearest = nearest;
if (!pathStack.contains(spiltPoint)) {//如果该节点未被便利过,则进行遍历,避免父子之间死循环
if (spiltPoint == null)
return nearest;
// if (spiltPoint == root) {//遍历终止
// if (node.distance(nearest.getValue()) > node.distance(spiltPoint.getValue()))
// return root;//如果root节点比当前最近节点近则返回root节点
// else //否则返回当前最近节点
// return nearest;
// }
pathStack.push(spiltPoint);
if (node.distance(nearest.getValue()) > node.distance(spiltPoint.getValue())) {
CurNearest = spiltPoint;
}//如果当前节点距离待测数据比较近,更新之。
if (node.isTangential(CurNearest.getValue(), spiltPoint.getValue(), spiltPoint.getFlag())) {
//发生相交的情况
//去查找最邻近节点的兄弟节点
//if (spiltPoint.getLeft() == nearest && spiltPoint