1. KD树构建
KD树本质是一颗二叉树,构建首先要选择一个根节点,一般选择方差较大维度的中间数节点作为根节点
列如:有以下数据集
(4, 7), (9, 6), (8, 1), (2, 3), (5, 4), (7, 2)
其中 x 的方差为 2.4094720491334933 y的方差为 2.1147629234082532
这里 x 的方差 大于 y的方差 我们选择x轴 排序后 的中间数 作为 根节点
以 x 排序 [(2., 3.) (4., 7.) (5., 4.) (7., 2.) (8., 1.) (9., 6.)] 这里 因为是偶数 中间数有两个 (5,4) 和 (7, 2) 随便选取一个就行 我们选择 (7, 2) 作为根节点 在根节点左边的数列 分到左子树,在根节点右边的数列分到右子树
此时的树结构如图所示
第二步
将左右数列 分别 以 y 排序 去中间数
左边数列 (2., 3.) (5., 4.) (4., 7.) 中间数为 (5, 4) 插入 左子树
右边数列 (8,1) (9, 6) 将中间数 (9,6) 插入到 右子树 如果有两个数 选择较大的那一个
此时的树结构如图所示
第三步
节点 (5,4)左节点还剩下 (2., 3.) 右节点剩下 (4., 7.) 分别插入到左右子树
右边 只剩 (8,1)插入右子树
至此 一颗KD树就构建完毕
此时的树结构如图所示
2. KD树搜索
KD树搜索算法较为复杂 需要仔细研究 (个人也是学习记录, 有错误请指正 谅解)
KD 树搜索 分为两个大部分 一个是向下搜索 一个向上回溯
1. 向下搜索
从根节点开始,首先对比 x 的点大小,如果小于走左边反之右边
第二次 对比 y 的大小 ,小于走左边反之右边
如此往复 x y x y 一直走到子节点为空结束
走到 最后节点, 计算点之间的距离,保存该距离 r 将其记录最近节点,将当前点标记为以访问,然后向上回溯
2. 向上回溯
以 搜索点为圆心,以r为半径,判断是否与当前节点某轴线相交如果相交 则此节点的子节点可能有比 记录的最近节点 还要近的节点,将当前节点记录以访问,然后向下搜索
如果没有与当前节点某维度相交, 那么该子节点则没有比记录最近节点更近的节点,往上回溯
一直往上回溯 一直到根节点 结束搜索
以下 直接上代码
import copy
import math
import numpy as np
class ThreeNode:
left_node = None
right_node = None
p_node = None
value = None
is_access = False
class KdtSearchPathItem:
node = None
dist = 0
def __str__(self):
return "距离 " + str(self.dist) + " 值" + str(self.node.value)
class Kdt:
def __init__(self):
self.dtype = np.dtype([('x', float), ('y', float)])
self.data = np.array([(4, 7), (9, 6), (8, 1), (2, 3), (5, 4), (7, 2)], dtype=self.dtype)
self.search_paths = []
# 计算标准差
self.x_std = np.std(self.data['x'])
self.y_std = np.std(self.data['y'])
# 选择一个维度 的中间值作为根节点 选择方差最大的维度
x_var = np.var(self.data['x'])
y_var = np.var(self.data['y'])
if x_var > y_var:
axis = 'x'
else:
axis = 'y'
# 初始化根节点
three = ThreeNode()
# 递归创建子节点
self.build_three(self.data, axis, three)
self.three = three
def get_axis(self, axis):
if axis == 'x':
new_axis = 'y'
else:
new_axis = 'x'
return new_axis
def build_three(self, n_list, axis, three):
n_list = np.sort(n_list, order=axis)
middle_index = self.get_middle(len(n_list))
# 将中位值作为树的节点值
three.value = n_list[middle_index]
# 计算下一次排序和切割的维度
new_axis = self.get_axis(axis)
# 初始化左右子树
three.left_node = ThreeNode()
three.left_node.p_node = three
three.right_node = ThreeNode()
three.right_node.p_node = three
# 如果 list 大小大于2 正常切割
if len(n_list) > 2:
left_n_list = n_list[0: middle_index]
right_n_list = n_list[middle_index + 1:]
self.build_three(left_n_list, new_axis, three.left_node)
self.build_three(right_n_list, new_axis, three.right_node)
# list大小等于2 那么只剩下一个元素了 (已经分割过一次 每一次分割都会减少一个元素) 添加到左叶节点
elif len(n_list) == 2:
self.build_three([n_list[0]], new_axis, three.left_node)
# 找到中间数 偶数采用较大的那个数
@staticmethod
def get_middle(length):
if length == 1:
return 0
if length == 2:
return 1
if length % 2 == 0:
index = length / 2
else:
index = (length - 1) / 2
return int(index)
def _search_down(self, value, axis, three):
# 1. 判断是否是最后一个节点 如果不是则继续往下搜索
if (three.left_node is None or three.left_node.value is None) and (three.right_node is None or three.right_node.value is None):
self._append_search_paths(three, value)
three.is_access = True
for i in self.search_paths:
print(i)
# 往上搜索
return self._search_up(value, self.get_axis(axis), three.p_node)
# 正常节点 依次往下搜索
self._append_search_paths(three, value)
# 判断往哪个方向搜索
if value[axis] < three.value[axis]:
# 有可能 左子树为空 如果为空直接上跳
if three.left_node is None or three.left_node.value is None:
three.is_access = True
return self._search_up(value, self.get_axis(axis), three.p_node)
return self._search_down(value, self.get_axis(axis), three.left_node)
else:
if three.right_node is None or three.right_node.value is None:
three.is_access = True
return self._search_up(value, self.get_axis(axis), three.p_node)
return self._search_down(value, self.get_axis(axis), three.right_node)
def _search_up(self, value, axis, three):
three.is_access = True
# 判断是否与轴线相交
if axis == 'x':
point = np.array((three.value['x'], value['y']), dtype=self.dtype)
else:
point = np.array((value['x'], three.value['y']), dtype=self.dtype)
is_intersect = self._is_intersect(value, point, self.search_paths[0].dist)
print(is_intersect, value, point, self.search_paths[0].dist)
if is_intersect:
# 相交
# 如果两边都被访问过 则网上跳
if three.left_node.is_access and three.right_node.is_access:
if three.p_node is None:
return
return self._search_up(value, self.get_axis(axis), three.p_node)
# 有一边没有被访问
access_node = three.left_node if not three.left_node.is_access else three.right_node
return self._search_down(value, self.get_axis(axis), access_node)
else:
# 不相交 直接往上跳
if three.p_node is None:
return
return self._search_up(value, self.get_axis(axis), three.p_node)
pass
pass
def _append_search_paths(self, three, value):
item = KdtSearchPathItem()
item.node = three
item.dist = self._get_dist(three.value, value)
self.search_paths.append(item)
# 从小到大排序
self.search_paths.sort(key=lambda x: x.dist)
'''
判断以x点画圆 半径为r 是否相交与点y
'''
def _is_intersect(self, x, y, r):
if x['x'] == y['x']:
return abs(x['y'] - y['y']) < r
if x['y'] == y['y']:
return abs(x['x'] - y['x']) < r
d = math.sqrt((abs(x['x'] - y['x']) ** 2) * (abs(x['y'] - y['y']) ** 2))
return d < r
"""
标准化欧氏距离计算
"""
def _get_dist(self, v1, v2):
d = math.sqrt(((v1['x'] - v2['x']) / self.x_std) ** 2 + ((v1['y'] - v2['y']) / self.y_std) ** 2)
return d
def search(self, value):
self._search_down(np.array(value, dtype=self.dtype), 'x', self.three)
# copy一个副本 用于返回
r = copy.copy(self.search_paths)
# 将标记过的节点 重置
for i in self.search_paths:
i.is_access = False
self.search_paths = []
return r
kdt = Kdt()
paths = kdt.search((3, 20))
print("最近距离是: ", paths[0])
程序输出
C:\Users\Administrator\PycharmProjects\stu01\venv\Scripts\python.exe C:/Users/Administrator/PycharmProjects/stu01/算法/KNN/2KDT.py
距离 6.16125544670923 值(4., 7.)
距离 7.6112568765057285 值(5., 4.)
距离 8.671977042761826 值(7., 2.)
False (3., 20.) (3., 4.) 6.16125544670923
True (3., 20.) (7., 20.) 6.16125544670923
True (3., 20.) (7., 20.) 6.16125544670923
最近距离是: 距离 6.16125544670923 值(4., 7.)
原理看完还是得敲一敲代码 不然记忆不够深刻