写在前面
1.关于KNN的简要叙述见博文https://blog.csdn.net/kodoshinichi/article/details/106819524
KNN算法的重点在于三个参数的选择:K的取值、距离度量方法以及分类决策方法
2.之所以要从KNN进行优化发展成KDTree是因为需要对k维空间进行一个快速检索(优化KNN检索的开销)。
3.KDTree本质上还是二叉树,表示对k维空间的一个划分,其每个节点对应于k维空间划分中的一个超矩形区域。利用KDTree优化检索的过程是因为它可以省去对大部分数据的搜索过程,从而减少搜索的计算量。
一.KD树的构造
核心过程:
用不断垂直于坐标轴的超平面将k维空间进行切分,构成一系列的k维超矩形区域。
最终结果:
建立的一棵二叉树,树中的每个节点就是对应一个k维的超平面矩形区域。
参数确定:
每一次二叉树的分支就相当于在K维空间中对选定的坐标轴做一条垂线,但是需要明确该直线经过哪一点才能将直线定下来。
——选择训练实例点在选定坐标轴上的中位数作为切分点,这样得到的kd树是平衡的。
划分图示:
二.KDTree的搜索
核心过程:先找到包含所要查找的叶节点对应的区域,再回溯检查其相邻区域中是否有更优解。
我个人理解把KDTree的搜索过程划分成两个部分,首先自上而下进行指定叶节点的搜寻(树的递归搜寻算法即可满足),再自下而上进行回溯检查。
自上而下:从根节点出发,递归地向下访问kd树,若目标点x当前维的坐标小于切分点的坐标,则移动到左子节点,否则移动到右子节点,直到子节点为叶节点。
自下而上:自上而下的结果已经返回了一个当前最优解
若有其他最优点,那么也一定处在以目标点为中心并经过了当前最近邻点的超球体的内部。
回溯的过程主要是检查该节点的父节点的另一侧分支,如果父节点的另一个子节点的超矩形区域与超球体相交,就说明另一侧父节点中存在更近的实例点,更新当前最近邻点。
直到回溯的过程不再存在更优近邻点,则算法停止。
KDTree搜索的图示
三.KDTree对时空数据最近邻点的索引
1.实验要求:给定的两个数据集为时空数据集,实验只采用最后两维数据,分别代表数据点的经纬度值,需要对这部分数据建立索引结构,并进行最近邻点的索引。
INPUT: 节点50;OUTPUT: 该节点的最近邻节点以及算法构建查询时间
2.实验代码
#coding:utf-8
import numpy as np
from time import *
class KD_node(object):
#定义的kd树节点
def __init__(self, point = None, split = None,index = None, LL = None, RR = None):
#节点值
self.point = point
#节点分割维度
self.split = split
# 该坐标点在原数据集的下标号
self.index = index
#节点左孩子
self.left = LL
#节点右孩子
self.right = RR
'''
建立KD树
'''
def createKDTree(root, data_list):
#print(type(data_list))
#start是传入的数据集的第一个元素在原数据集中的下标
length = len(data_list)#第一次传入的的确是ndarray,但之后递归的时候传入的就是列表了
if length == 0:
return
dimension = len(data_list[0])-1#去掉标号所在的维度
max_var = 0
split = 0
for i in range(dimension):
ll = []
for t in data_list:
ll.append(t[i+1])#取数据的时候也要略过标号那一维
var = computerVariance(ll)
if var > max_var:
max_var = var
split = i
#以最大方差的点为维度,进行划分
data_list = sorted(data_list, key = lambda x : x[split+1])
#找中位下标
point = data_list[int(length / 2)][1:]#取后两维的坐标
index =int(data_list[int(length / 2)][0])#取第一维的id号
#print(index)
root = KD_node(point, split, index)
#递归建立左子树
root.left = createKDTree(root.left, data_list[0:int(length / 2)])
#递归建立右子树
root.right = createKDTree(root.right, data_list[int(length / 2) + 1 : length])
return root
#计算方差,对方差公式进行转化了,每个数据点的平方均值减去数据均值的平方
def computerVariance(arraylist):
#arraylist = array(arraylist)
for i in range(len(arraylist)):
arraylist[i] = float(arraylist[i])
length = len(arraylist)
sum1 = sum(arraylist)
array2 = [arraylist[i]*arraylist[i] for i in range(length)]
sum2 = sum(array2)
mean = sum1 / length
variance = sum2 / length - mean ** 2
return variance
'''
基于KD树进行索引
'''
#用于计算欧式距离
def computerDistance(pt1, pt2):
sum = 0.0
for i in range(len(pt1)):
sum = sum + (pt1[i] - pt2[i]) ** 2#注意第一维被标号占用,传入的只有两维,不过循环遍历的控制次数要修改
return sum ** 0.5
#query中保存着最近k节点,先进行K近邻查询,再从中挑出最优解
def findNN(root, query,k):
min_dist = computerDistance(query,root.point)
node_K = []#node_K中和nodeList保持同步,存入计算的距离的结果
nodeList = []#存放自上而下进行搜索的过程中途径的父节点
idList = []#存放最优解在原始数据集中的下标
temp_root = root
#为了方便,在找到叶子节点同时,把所走过的父节点的距离都保存下来,下一次回溯访问就只需要访问子节点,不需要再访问一遍父节点。
while temp_root:
nodeList.append(temp_root)#将当前父节点存入,只要还有子节点,那么在判断的位置上存入的都是父节点
dd = computerDistance(query,temp_root.point)
if len(node_K) < k:
node_K.append(dd)
idList.append(temp_root.index)
else :
max_dist = max(node_K)
if dd < max_dist:#当在进行KNN查询时,列表中已经存入当前K个距离最小的解,下一个解进入的条件就是它是否小于最大距离
index = node_K.index(max_dist)#求表中最大值对应的索引
del(node_K[index])
del(idList[index])
node_K.append(dd)
idList.append(temp_root.index)
ss = temp_root.split
#找到最靠近的叶子节点
#在当前划分轴的那个维度上进行比较,找到下一个进行比较的节点
if query[ss] <= temp_root.point[ss]:
temp_root = temp_root.left
else:
temp_root = temp_root.right
print('node_k :',node_K)
print('idList :',idList)
#回溯访问父节点
while nodeList:
back_point = nodeList.pop()
ss = back_point.split
print('父亲节点 : ',back_point.point,'维度 :',back_point.split,'节点标号 :',back_point.index)
max_dist = max(node_K)
print('该节点到查询节点的距离 :',computerDistance(back_point.point,query))
#若满足进入该父节点的另外一个子节点的条件
#算法的描述是以查询点为中心,以中心点和当前最近点的距离为半径做一个圆,与父节点相交的话,那么该父节点的另一个子节点有可能是更近点
#代码实现起来就是,以查询点到当前父节点的划分轴的距离是否小于KNN列表中的最大值,最大值是因为K有多个的话,只要能进到表中即可
if len(node_K) < k or abs(query[ss] - back_point.point[ss]) < max_dist :
#进入另外一个子节点
if query[ss] <= back_point.point[ss]:
temp_root = back_point.right
else:
temp_root = back_point.left
if temp_root:
nodeList.append(temp_root)
curDist = computerDistance(temp_root.point,query)
print('curDist :',curDist)
#如果当前点满足入表的条件,但是表已经满了的时候
if max_dist > curDist and len(node_K) == k:
index = node_K.index(max_dist)
#把当前表内最大的数据点给移出,再把新的点放入表中
del(node_K[index])
del(idList[index])
node_K.append(curDist)
idList.append(temp_root.index)
elif len(node_K) < k:
#如果表没有满,那么这个点无理由进入,因为一定是当前 前K个最近的点
node_K.append(curDist)
idList.append(temp_root.index)
return node_K,nodeList,idList
if __name__ == "__main__":
index = 0
data = np.loadtxt("BJ/real.txt")[:, -2:]#导入数据
id_sum = data.shape[0]
data_id = np.loadtxt("BJ/real.txt")[:, 0].reshape((id_sum, 1))#把数据的节点也导入
data2 = np.concatenate((data_id, data), axis=1)
begin_time = time()
root = createKDTree(None, data2)
res_list,_,res_id = findNN(root, data2[50][1:], 2)
end_time = time()
run_time = end_time - begin_time
print("最优距离值为: ",res_list)
print("最优点为: ",res_id)
print("程序运行时间: ",run_time)
3.数据集与代码源文件放在github上
https://github.com/Square-of-W/ML-Code/tree/KDTree
写在后文
本文是在自己理解了kdTree之后回忆复习整理所作,初次接触kd树的可以先阅读以下第一篇原理讲解文章,我当初就是先看这篇文章理解kd树的。
1.kd树算法之详解篇-知乎
2.代码参考
3.参考:《统计学习方法》-李航