为什么要用kd树实现k近邻算法
与简单的KNN实现不同的是利用kd树可以显著减小距离的计算次数。简单KNN实现时要计算目标点与所有其他点的距离,而kd树不必计算所有距离。
简单KNN可以参考k近邻法原理及编程实现。
构造kd树
在这里《统计学习方法》用到了平衡kd树,在学习平衡kd树之前可以先学习一下比较简单的平衡二叉树。这里我们直接按照《统计学习方法》中构造kd树的步骤来写代码:
构造节点类:
class Node:
"""构造节点类,节点包括当前节点data,当前切分维度sp,当前节点的左右子节点left和right"""
def __init__(self, data, sp=0, left=None, right=None):
self.data = data
self.sp= sp # sp代表用第sp维数据进行划分左右子树
self.left = left
self.right = right
建立kd树:
# dataset代表原本的数据集,sp代表用第sp维数据进行划分左右子树,dimension代表数据的维度
def create(dataset, sp):
if len(dataset) == 0: # 递归的出口
return None
# 对当前维度进行排序,当前维度的计算如上面算法描述中所说的那样l=j(mod k)+1,
# 即用j对k求余数再加1
# 第一次的sp=0,故以第零维为准进行升序排列
dataset = sorted(dataset, key=lambda x: x[sp])
mid = len(dataset)//2 # 找到以当前维度为准进行排序之后中间的数
# split by the median #
dat = dataset[mid] # 新节点的值为dat
# 重复上面的步骤进行递归,直到到达递归的出口
# 由于在实际的数据中样本是从x[0]开始的,而不是x[1],
# 所以这里与书中的算法描述有些不一样,这里没有加1
# 这里(sp+1) % dimension 对当前深度进行了更新,
# dimension是个固定值,表示数据集的维数
return Node(dat, sp, create(dataset[:mid], (sp+1) % dimension),\
create(dataset[mid+1:], (sp+1) % dimension))
构造过程如下所示:
搜索kd树找出最近邻:
# x接受的是目标点, near_k接受的是k近邻中的k,p=2代表计算距离时使用L2范数
def nearest(self, x, near_k=1, p=2):
# use the max heap builtin library heapq
# init the elements with -inf, and use the minus distance for comparison
# the top of the max heap is the min distance.
self.knn = [(-np.inf, None)]*near_k # -np.inf表示负无穷大
def visit(node):
if not node == None: # 递归的出口
# 计算目标点与分离超平面之间的距离
dis = x[node.sp] - node.data[node.sp]
# 如果目标点与分离超平面的距离小于0,则访问它的左子树,否则访问它的右子树
# 直到到达递归出口,即可获得当前与目标点距离最近的点
visit(node.left if dis < 0 else node.right)
# 计算一下这个当前最短距离
curr_dis = np.linalg.norm(x-node.data, p)
# 将最短距离放进堆中
heapq.heappushpop(self.knn, (-curr_dis, node))
# 比较当前最短距离和目标点到分离超平面的距离
# 如果到分离超平面的距离短的话就访问另一个节点
if -(self.knn[0][0]) > abs(dis): # 递归出口
visit(node.right if dis < 0 else node.left)
visit(self.root) # self.root表示根节点
self.knn = np.array(
[i[1].data for i in heapq.nlargest(near_k, self.knn)])
return self.knn
完整代码:
import numpy as np
import heapq
class Node:
def __init__(self, data, sp=0, left=None, right=None):
self.data = data
self.sp = sp # sp代表
self.left = left
self.right = right
class KDTree:
def __init__(self, data): # data代表原本的样本集
k = data.shape[1] # k代表数据的维数
print("k=", k)
def create(dataset, sp):
print("it_sp=", sp)
if len(dataset) == 0: # 递归的出口
return None
# sort by current dimension
# sp=0 故以第零维为准进行升序排列
dataset = sorted(dataset, key=lambda x: x[sp])
print("dataset=", dataset)
mid = len(dataset)//2
# split by the median
dat = dataset[mid]
print("dat=", dat)
print("sp = ", sp)
return Node(dat, sp, create(dataset[:mid], (sp+1) % k),\
create(dataset[mid+1:], (sp+1) % k))
self.root = create(data, 0)
print("self_root=", self.root.data)
def nearest(self, x, near_k=1, p=2): # x接受的是目标点, near_k接受的是k近邻中的k
print("near_k=", near_k)
# use the max heap builtin library heapq
# init the elements with -inf, and use the minus distance for comparison
# the top of the max heap is the min distance.
self.knn = [(-np.inf, None)]*near_k # -np.inf表示负无穷大
print("self_knn=", self.knn)
def visit(node):
if not node == None: # 递归的出口
# cal the distance to the split point, i.e. the hyperplane
print("visit_sp=", node.sp)
dis = x[node.sp] - node.data[node.sp]
# visit the child node recursively
# if returned, we get the current nearest point
visit(node.left if dis < 0 else node.right)
# cal the distance to the current nearest point
print("node_data=", node.data)
curr_dis = np.linalg.norm(x-node.data, p)
# push the minus distance to the heap
heapq.heappushpop(self.knn, (-curr_dis, node))
# compare the distance to the hyperplane with the min distance
# if less, visit another node.
if -(self.knn[0][0]) > abs(dis):
visit(node.right if dis < 0 else node.left)
visit(self.root)
self.knn = np.array(
[i[1].data for i in heapq.nlargest(near_k, self.knn)])
return self.knn
if __name__ == "__main__":
from pylab import *
data = array([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]])
kdtree = KDTree(data) # create the kd_tree
target = array([7.5, 3])
kdtree.nearest(target, 2) # find nearest in kd_tree
# print("kdtree=", kdtree.nearest(target, 2))
plot(*data.T, 'o')
plot(*target.T, '.r')
plot(*kdtree.knn.T, 'r+')
show()