# The input data can be customized: set the query point and k, and the k nearest points are returned.
import numpy as np
class KD_Node:
    """A single node of a KD-tree."""

    def __init__(
        self,
        point=None,
        split=None,
        leftNode=None,
        rightNode=None,
    ):
        """
        Store the node payload and its two child links.

        :param point: the data point held by this node
        :param split: index of the dimension this node splits on
        :param leftNode: left child (None when absent)
        :param rightNode: right child (None when absent)
        """
        # Payload first, then the links to the two subtrees.
        self.point, self.split = point, split
        self.leftNode, self.rightNode = leftNode, rightNode
def setData():
    """Return a fresh copy of the built-in sample data set (six 2-D points)."""
    # A list literal is evaluated on every call, so callers may freely
    # reorder or modify the result without affecting later calls.
    return [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
def createKDTree(group):
    """
    Recursively build a KD-tree from a collection of points.

    At every level the dimension with the largest variance is chosen as the
    splitting dimension (the lowest index wins ties, and dimension 0 is used
    when no dimension has a variance greater than zero). The points are
    ordered along that dimension, the element at index ``len(group) // 2``
    becomes the node, the elements before it form the left subtree and the
    elements after it form the right subtree.

    The caller's ``group`` is left untouched: ordering is done on a copy.

    :param group: sequence of points; the number of dimensions is taken
                  from the first point
    :return: the root KD_Node, or None when ``group`` is empty
    """
    length = len(group)
    if length == 0:
        return None
    dimension = len(group[0])
    # Splitting dimension and the largest variance seen so far.
    split = 0
    max_var = 0
    for axis in range(dimension):
        var = cal_variance([data[axis] for data in group])
        if var > max_var:
            max_var = var
            split = axis
    # sorted() returns a new list, so the argument is not reordered as a
    # hidden side effect of building the tree.
    ordered = sorted(group, key=lambda item: item[split])
    middle = length // 2
    root = KD_Node(ordered[middle], split)
    root.leftNode = createKDTree(ordered[:middle])
    # The right subtree starts one past the median, which the node itself holds.
    root.rightNode = createKDTree(ordered[middle + 1:])
    return root
def cal_variance(dataList):
    """
    Compute the variance of a set of values with the identity
    D(X) = E(X^2) - [E(X)]^2.

    :param dataList: all values of the data set along one dimension
    :return: the variance of those values
    """
    values = np.array(dataList)
    count = len(dataList)
    # E(X): mean of the values.
    mean = values.sum() / count
    # E(X^2): mean of the element-wise squares.
    mean_of_squares = np.square(values).sum() / count
    return mean_of_squares - mean * mean
def findKN(point, root, k):
    """
    Search a KD-tree for the ``k`` stored points nearest to ``point``.

    The tree is walked with an explicit stack. At each node the subtree on
    the query's side of the splitting plane is explored first; the subtree
    on the other side is explored only if fewer than ``k`` candidates have
    been collected so far, or if the splitting plane is closer to the query
    than the current k-th best candidate. Every node that is reached is
    handled the same way, so both of its subtrees are considered.

    Distances are compared in squared form, which gives the same ordering
    as the Euclidean distance without taking square roots.

    :param point: the query point (indexable sequence of coordinates)
    :param root: root node of the KD-tree; each node provides ``point``,
                 ``split``, ``leftNode`` and ``rightNode``. May be None.
    :param k: number of neighbours to return
    :return: None when ``k < 1``; otherwise a list of at most ``k`` stored
             points ordered from nearest to farthest (empty for an empty
             tree). Each stored point appears at most once.
    """
    if k < 1:
        return None
    if root is None:
        return []

    def squared_dist(other):
        # Squared Euclidean distance between the query and ``other``.
        return sum((a - b) ** 2 for a, b in zip(point, other))

    # Candidates as (squared distance, point), kept in ascending order and
    # never longer than k, so best[-1] is always the current worst candidate.
    best = []
    # Each entry is (node, lower bound on the squared distance from the
    # query to any point stored in that node's subtree).
    pending = [(root, 0)]
    while pending:
        node, bound = pending.pop()
        # Skip a subtree that cannot contain anything closer than the
        # current worst candidate, but only once k candidates are held.
        if len(best) >= k and bound >= best[-1][0]:
            continue
        dist = squared_dist(node.point)
        if len(best) < k or dist < best[-1][0]:
            best.append((dist, node.point))
            best.sort(key=lambda entry: entry[0])
            if len(best) > k:
                # Drop the farthest candidate.
                best.pop()
        axis = node.split
        gap = point[axis] - node.point[axis]
        if gap < 0:
            near, far = node.leftNode, node.rightNode
        else:
            near, far = node.rightNode, node.leftNode
        # Push the far side first so the near side is popped (and fully
        # explored) before the far side's bound is tested.
        if far is not None:
            pending.append((far, max(bound, gap * gap)))
        if near is not None:
            pending.append((near, bound))
    return [entry[1] for entry in best]
def sortP(group, point):
    """
    Sort ``group`` in place by ascending distance to ``point``.

    Distances are computed with ``cal_dist``, once per element. Points that
    are exactly the same distance away keep their relative input order.

    :param group: list of points; it is reordered in place
    :param point: the reference point distances are measured to
    :return: the same list object that was passed in, now sorted
    """
    # A keyed sort evaluates the distance once per element instead of
    # twice per pairwise comparison.
    group.sort(key=lambda candidate: cal_dist(candidate, point))
    return group
def cal_dist(point1, point2):
    """
    Compute the Euclidean distance between two points.

    The number of coordinates is taken from ``point1``; ``point2`` is
    indexed at the same positions.

    :param point1: first point (sequence of coordinates)
    :param point2: second point (sequence of coordinates)
    :return: the distance between the two points
    """
    squared_total = 0
    for index, coordinate in enumerate(point1):
        squared_total += (coordinate - point2[index]) ** 2
    # Raising to the power 1/2 takes the square root of the summed squares.
    exponent = 1 / 2
    return squared_total ** exponent