KNN算法是将待测样本与训练样本的特征进行比较,取k个与待测样本最接近的训练样本(如计算欧氏距离),其中k个样本中大多数属于哪一类别便也将待测样本分类为哪一类别(最大分类决策)。
Kd树是高效的实现KNN算法的一种实现方式,其中的k代表有k个维度。
Kd树的每个节点都可以看做一颗子树,他们拥有一致的属性和方法,因此kd树的构造和搜索都以递归的方式进行。
首先构造节点类,每个节点都有根值、左孩子、右孩子三个基本属性。class node:
def __init__(self):
#根节点
self.root = None
#左孩子
self.left = None
#右孩子
self.right = None
然后开始构造kd树,struct函数需要传入数据集,树的高度和父节点。构造出的节点属性包括节点所在高度,节点对应的搜索数据的轴,父节点,节点值(第n轴排序得到的中位数所在节点),标签,左子树(由中位数左侧的数据递归构建)和右子树(由中位数右侧的数据递归构建)。def struct(self,x,height=0,father=None):
#数据集维度
dimension = x.shape[1] - 1
#节点高度
self.height = height
#节点对应的轴
self.axis = self.height % dimension
#父节点
self.father = father
#按第axis轴排序
x = np.array(sorted(x,key=lambda a:a[self.axis]))
#选取中位数作为切分点
median = x.shape[0] // 2
#根节点保存数据
self.root = x[median][:-1]
#保存数据标签
self.target = int(x[median][-1])
#递归构建kd树,其中若一侧无剩余元素,则对应孩子为None
if x[:median].shape[0]:
self.left = node()
self.left.struct(x[:median],height=self.height+1,father=self)
else:
self.left = None
if x[median+1:].shape[0]:
self.right = node()
self.right.struct(x[median+1:],height=self.height+1,father=self)
else:
self.right = None
这样,一棵kd树便递归的构造完成了,通过每个节点的root属性访问节点的数据值,left属性访问左节点,right属性访问右节点。
下面开始搜索kd树。
首先以最近邻为例进行搜索。
用特定变量来保存搜索到的当前最近点和根节点,保存的根节点用于之后判断是否搜索到了树顶。def search(self,x):
#保存当前最近点