目录
一、K-d Tree的发现(Who , Where , When)
对于信息收集而言,最基本的方法是进行5W1H提问,Who,Where,When,What,Why,How。对于计算机相关的知识点亦如是。接下来,我将从这几个方面切入,介绍K-d Tree。
一、K-d Tree的发现(Who , Where , When)
K-d树(K-dimensional tree)是由Jon Louis Bentley在1975年发表的论文《Multidimensional Divide-and-Conquer》中首次提出的数据结构。该论文是在斯坦福大学的计算机科学系发表的。K-d树是一种多维空间中对数据点进行结构化存储和快速搜索的数据结构。它被广泛应用于数据挖掘、模式识别、图形学等领域。因其在高维空间中的高效查询表现而受到广泛关注。
二、什么是K-d Tree(What)
K-d树(K-dimensional tree)是一种多维空间中用于结构化存储和快速搜索数据点的树形数据结构。它通过递归地将空间划分为更小的区域,构建一棵二叉树来表示这些区域,从而实现对数据点的高效组织和检索。
K-d树的核心思想是在每个节点上选择一个维度进行切分,将数据集分割成两个子集,使得每个子集中的数据点都在切分维度上的某一侧。这样,树的每个节点都代表一个数据点,并且以该节点为根节点的子树对应于在切分维度上的数据区间。
通过这种划分方式,K-d树在搜索时能够快速定位到目标点附近的数据点,从而实现高维数据的快速查找。它被广泛应用于诸如数据挖掘、模式识别、图形学等领域,特别是在需要处理大规模高维数据集的情况下,具有很高的效率和实用性。
三、如何构建K-d Tree (How)
首先以一个简单的例子(二维特征)切入,了解K-d树的构建。
在二维平面上有以下六个点:(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)
① 确定分割轴:由于X轴方差较大,因此以X轴为特征进行分割。将上述六个点的点集按照二维平面的第一维(X轴)对数据进行排序,排序结果为:
(2,3),(4,7),(5,4),(7,2),(8,1),(9,6)
② 取得上述点集的中位数处的点(若长度为偶数则取len/2+1)。在上述例子中取到的点为(7,2),在该点处对平面进行划分。划分后结果为:
左支:{(2,3),(4,7),(5,4)} 右支:{(8,1),(9,6)}
③ 将两棵子树的所有点按照Y值进行排序,排序结果为:
左支:{(2,3),(5,4),(4,7)} 右支:{(8,1),(9,6)}
④将左枝和右枝分别取中位数点作为新的节点,按照第一维继续排序,直致每一个子枝上只剩一个点,这个点被称作叶子。
上述过程画成图如下:
注意:每次划分之后排序的情况是按照维度依次类推的,例如在二维平面是按x,y,x,y进行下去,在三维空间就是x,y,z,x,y,z进行下取。
四、K-d Tree的应用 (Why)
最邻近搜索(Nearest Neighbor Search) 1-NN
构建完K-d 树之后,举一个简单的例子来观察K-dTree是如何工作的。
举个例子:根据上面构建的树,查找(1,2)的最近邻。(为了简便计算,此处距离使用欧氏距离平方)
- 计算当前节点(7,2)的距离,为36,并且暂定为(7,2),根据当前分割轴的维度X(1 < 7),选取左支。
- 计算当前节点(5,4)的距离,为24,由于24< 36,暂时定为(5,4),根据当前分割轴维度Y(2< 4),选取左支。
- 计算当前节点(2,3)的距离,为2,由于2 < 24,暂定为(2,3),根据当前分割轴维度(1 < 2),选取左支,而左支为空,回溯上一个节点。
- 计算(1,2)与(5,4)的分割轴{y = 4}的距离平方,如果2小于距离值平方,说明就是最近值。如果大于距离值,说明,还有可能存在值与(1,2)最近,需要往右支检索。
由于2< 4,我们找到了最近邻的值为(2,3),最近距离为。
实际案例应用:预测酒的质量
特征为11维,最后一维是Quality,即酒的质量,是需要预测的y值。测试集只给出11维数据,需要根据训练好的K-d树,来对测试集进行质量的预测。
代码简述:
完整代码实现:
import numpy as np
import pandas as pd
import sys
# node类
class Node:
def __init__(self, point, left=None, right=None):
self.point = point
self.left = left
self.right = right
# 建立kdtree
def BuildKdTree(points, depth=0):
##没东西了,返回空
if len(points) == 0:
return None
##只有一个节,返回他
elif len(points) == 1:
return Node(points2[0])
#选取二分的特征
axis = depth % len(points[0])
#排序
points.sort(key=lambda x: x[axis])
#寻找中位数
median = len(points) // 2
##找到中位数所处的节点,递归处理左子树,右子树
return Node(
point=points[median],
left=BuildKdTree(points[:median], depth + 1),
right=BuildKdTree(points[median + 1:], depth + 1)
)
# 找最近邻
def find_nearest_neighbor(root, query_point):
best = [None, float('inf'), -1, -1] # 添加一个位置用于存储最近邻点的索引和特征索引
def search(node, depth=0):
nonlocal best
if node is None:
return
axis = depth % len(query_point)
dist = np.linalg.norm(np.array(query_point) - np.array(node.point))
if dist < best[1]:
best = [node.point, dist, -1, -1] # 将最近邻点的索引和特征索引一起存储
if query_point[axis] < node.point[axis]:
search(node.left, depth + 1)
else:
search(node.right, depth + 1)
search(root)
return best
# 读入train
if len(sys.argv) != 4:
print("Usage: python nn_kdtree.py [train] [test] [dimension]")
else:
train_file = sys.argv[1]
test_file = sys.argv[2]
dimension = int(sys.argv[3])
points = pd.read_csv(train_file,sep = '\s+')
points_fea = points.iloc[:,0:dimension]
points2 = points_fea.copy()
points_y = points.iloc[:,-1]
points2 = points2.values.tolist()
points_fea = points_fea.values.tolist()
#测试集
query_points = pd.read_csv(test_file , sep = '\s+')
query_points = query_points.iloc[:,0:dimension]
# 构建树
root = BuildKdTree(points2)
for i in range(len(query_points)):
nearest_neighbor = find_nearest_neighbor(root, query_points.iloc[i])
feature_index = points_fea.index(nearest_neighbor[0]) # 根据最近邻点的特征找到对应的特征索引
y_value = points_y.iloc[feature_index] # 使用最近邻点在原始数据中的索引
print("最近邻点特征值:", nearest_neighbor[0])
print("对应的 y 值:", y_value)
注意,此处无法直接在终端运行。需要在控制台中找到对应文件路径,在当前目录下,键入
python nn_kdtree.py [train] [test] [dimension]
得到运行结果