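# KD-tree nearest-neighbour search: after locating the leaf region that contains the query point, backtrack up the tree (查找到该点属于的区域之后回溯).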
import heapq
from collections import deque

import numpy as np
from sklearn.preprocessing import StandardScaler
class Node:
    """A single node of a KD tree.

    Attributes are filled in by ``KDTree.build_tree``:
    ``feature`` is the index of the splitting dimension (left as None on
    leaves) and ``split`` is the ``(point, label)`` pair stored in the node.
    """

    def __init__(self):
        # Tree links; None when the parent/child does not exist.
        self.father = None
        self.left = None
        self.right = None
        # Splitting dimension index and the stored (point, label) pair.
        self.feature = None
        self.split = None

    @property
    def brother(self):
        """Return the other child of this node's parent, or None.

        None is returned for the root and for a node whose parent has
        no second child.
        """
        parent = self.father
        if parent is None:
            return None
        return parent.right if parent.left is self else parent.left

    def __str__(self):
        return f"feature: {self.feature!s}, split: {self.split!s}"
class KDTree():
    """KD tree over standardized samples with k-nearest-neighbour search.

    ``build_tree`` fits a ``StandardScaler`` on the training samples and
    stores the standardized points; ``n_n_search`` applies the same scaler to
    the query, so every distance is measured in the standardized space.
    """

    def __init__(self):
        self.root = Node()   # replaced by a fresh node on every build_tree call
        self.scaler = None   # StandardScaler fitted by build_tree

    def build_tree(self, X, y):
        """Build the KD tree from samples ``X`` and labels ``y``.

        The tree is built breadth-first. The root splits on feature 0 and each
        deeper internal node on the feature after its parent's, cycling through
        the columns. Each internal node stores its median sample as
        ``(point, label)`` in ``split``. Samples whose value on the split feature
        is strictly below the median's go to the left child; all others,
        including ties, go to the right child. A node that holds a single sample
        becomes a leaf, and its ``feature`` stays None.

        Args:
            X: 2-D array-like of samples, one row per sample.
            y: sequence of labels, one per row of ``X``.

        Raises:
            ValueError: if ``X`` is empty or ``len(X) != len(y)``.
        """
        if len(X) == 0:
            raise ValueError("X must contain at least one sample")
        if len(X) != len(y):
            raise ValueError("X and y must have the same length")

        # Standardize the features before building the tree.
        self.scaler = StandardScaler().fit(X)
        X = self.scaler.transform(X)
        n_features = np.shape(X)[1]

        # Start from a fresh root; reusing the previous one would keep stale
        # children left over from an earlier build.
        self.root = Node()
        # BFS queue of (node to fill in, indices of the samples in its region).
        queue = deque([(self.root, list(range(len(X))))])
        while queue:
            nd, idxs = queue.popleft()

            if len(idxs) == 1:
                # Leaf: nothing left to split.
                nd.split = (X[idxs[0]], y[idxs[0]])
                continue

            # (1) Choose the splitting feature: 0 at the root, then cycle.
            if nd.father is None:
                nd.feature = 0
            else:
                nd.feature = (nd.father.feature + 1) % n_features
            feat = nd.feature

            # (2) Take the median sample along that feature. sorted() is
            # stable, so tied values keep their input order.
            ordered = sorted(idxs, key=lambda i: X[i][feat])
            median_idx = ordered[len(idxs) // 2]
            nd.split = (X[median_idx], y[median_idx])
            split_val = X[median_idx][feat]

            # (3) Partition the remaining samples around the median value.
            idxs_left = []
            idxs_right = []
            for idx in idxs:
                if idx == median_idx:
                    continue  # the median sample is stored in this node itself
                if X[idx][feat] < split_val:
                    idxs_left.append(idx)
                else:
                    idxs_right.append(idx)

            # (4) Create and enqueue the non-empty children.
            if idxs_left:
                nd.left = Node()
                nd.left.father = nd
                queue.append((nd.left, idxs_left))
            if idxs_right:
                nd.right = Node()
                nd.right.father = nd
                queue.append((nd.right, idxs_right))

    def dfs(self, Xi, nd):
        """Descend from ``nd`` to a leaf, following ``Xi``'s side of each split.

        At a node with two children, go left when ``Xi`` is <= the node's
        split value on its feature, otherwise right. At a node with one child,
        go to that child. Return the leaf reached: the first candidate for the
        nearest sample in that subtree.
        """
        while nd.left is not None or nd.right is not None:
            if nd.right is None:
                nd = nd.left
            elif nd.left is None:
                nd = nd.right
            elif Xi[nd.feature] <= nd.split[0][nd.feature]:
                nd = nd.left
            else:
                nd = nd.right
        return nd

    def n_n_search(self, Xi, k=1):
        """Return the ``k`` samples nearest to ``Xi``.

        Descend to the leaf whose region contains ``Xi``, then walk back up.
        At every visited node, offer the stored sample to a bounded max-heap of
        the ``k`` best candidates. For each node on the way up, also search its
        sibling's subtree, unless the heap is full and the parent's splitting
        hyperplane is at least as far from ``Xi`` as the current k-th best
        candidate. In that case no sample on the other side can be strictly
        closer.

        Args:
            Xi: one query sample in the original (unstandardized) feature space.
            k: number of neighbours to return; a value below 1 yields ``[]``.

        Returns:
            A list of ``(-squared_distance, (point, label))`` tuples in heap
            order, so element ``[0]`` is the farthest of the returned
            neighbours. ``point`` and the distances are in the standardized
            space. Fewer than ``k`` entries are returned when the tree holds
            fewer samples.
        """
        if k < 1:
            return []

        Xi = self.scaler.transform([Xi])[0]

        # Max-heap of the best candidates, stored as (-dist, seq, split).
        # ``seq`` breaks distance ties so heapq never compares the ``split``
        # tuples, which hold numpy arrays whose comparison may raise ValueError.
        heap = []
        seq = 0

        # Each entry: (subtree root, leaf of that subtree to backtrack from).
        queue = deque([(self.root, self.dfs(Xi, self.root))])
        while queue:
            sub_root, nd = queue.popleft()
            while True:
                # Squared Euclidean distance from Xi to this node's sample.
                dist = np.linalg.norm(nd.split[0] - Xi) ** 2
                entry = (-dist, seq, nd.split)
                seq += 1
                if len(heap) < k:
                    heapq.heappush(heap, entry)
                elif -dist > heap[0][0]:
                    # Strictly closer than the current k-th best: replace it.
                    heapq.heapreplace(heap, entry)

                if nd is sub_root:
                    break  # this subtree has been fully backtracked

                parent = nd.father
                sibling = nd.brother
                if sibling is not None:
                    # Squared distance from Xi to the parent's splitting
                    # hyperplane. It is a lower bound on the squared distance
                    # to every sample in the sibling's subtree, because dfs
                    # always enters the child on Xi's side of the split.
                    gap = (Xi[parent.feature] - parent.split[0][parent.feature]) ** 2
                    if len(heap) < k or gap < -heap[0][0]:
                        queue.append((sibling, self.dfs(Xi, sibling)))
                nd = parent

        # Drop the tie-breaker so callers keep the (-dist, split) format.
        return [(neg_dist, split) for neg_dist, _, split in heap]
# Demo: query the 3 nearest neighbours of (3, 6) in a small 2-D data set.
X = [[2, 3], [4, 7], [5, 4], [7, 2], [8, 1], [9, 6]]
y = [1, 2, 3, 4, 5, 6]

kdtree = KDTree()
kdtree.build_tree(X, y)

test = list(kdtree.n_n_search([3, 6], 3))
# Report each neighbour as (squared distance in standardized space, label).
test = [(-neg_dist, split[1]) for neg_dist, split in test]
print(test)