前言
实现KNN算法时,当样本维度很大或训练样本数量巨大时,实现对数据的快速搜索,对提高计算效率有很大意义。
实现KNN最简单的方法是线性扫描,通俗来讲,就是要把输入数据和已有的所有训练数据都计算出来,比对距离大小,从众多样本中找K个近邻样本。为了提高查找计算效率,可以采用特殊的数据结构存储训练数据,以减少计算距离的次数。
代码原创,转载请注明出处。
一、KD树是什么?
KD树是一种查询索引结构,广泛应用于数据库索引中。
换成低维空间比较容易思考:
给定一个数组[99,56,85,23,66,12,101],用这组数将数轴划分成几部分:
(,12)(12,23)(23,56)(56,66)(66,85)(85,99)(99,101)(101,)
(12,23,56,66,85,99,101)
给一个数40,找到它所在的区间,我们会将40和数组中的中位数66比较,所以40在左区间
(12,23,56)
再将40和左区间的中位数23比较,于是40落在了右区间(23,56)
这里就是将66,23,99看做了一个节点,去划分整个一维空间,通过和它们比较,来判断输入的数应在哪一范围中。
当数据提升到二维空间,可以理解为用两个维度轮流作为标准,根据训练数据划分样本空间(构造矩形)。
- 初始以x1划分样本(2,4,5,7,8,9),取7做划分,于是二叉树根节点为(7,2)左右两侧划分的样本分别为(5,4)(2,3)(4,7)
(9,6)(8,1) - 做第二次切分,以x2划分样本,左右两侧都取x2轴坐标的中间值切分空间,左支取4,右支取6
- 第三次切分以x1划分,只有一点,取该点为节点划分
至此完成了kd树的生成!
二、构造kd树
代码实现
关于二叉树的其他相关知识需要学习数据结构,这里不做赘述。
# 定义二叉树节点类型
class TreeNode(object):
def __init__(self,loc = None,tag = 0,ln = None,lr = None):
self.loc = loc # 自身节点坐标
self.tag = tag # 自身标签
self.left = ln # 左叶子
self.right = lr # 右叶子
def showinfo(self):