knn
本文主要结合书中例3-2的数据,构建kd树,并预测点[6.8,0],[2,4.5]的最邻近点。
代码实现:
import numpy as np
class Node:
def __init__(self, data, lchild = None, rchild = None):
self.data = data
self.lchild = lchild
self.rchild = rchild
class KdTree:
def __init__(self):
self.kdTree = None
def create(self, dataSet, depth): #创建kd树,返回根结点
if (len(dataSet) > 0):
m, n = np.shape(dataSet) #求出样本行,列
midIndex = int(m / 2) #中间数的索引位置
axis = depth % n #判断以哪个轴划分数据
sortedDataSet = self.sort(dataSet, axis) #进行排序
node = Node(sortedDataSet[midIndex]) #将节点数据域设置为中位数,具体参考下书本
# print sortedDataSet[midIndex]
leftDataSet = sortedDataSet[: midIndex] #将中位数的左边创建2改副本
rightDataSet = sortedDataSet[midIndex+1 :]
node.lchild = self.create(leftDataSet, depth+1) #将中位数左边样本传入来递归创建树
node.rchild = self.create(rightDataSet, depth+1)
return node
else:
return None
def sort(self, dataSet, axis): #采用冒泡排序,利用aixs作为轴进行划分
sortDataSet = dataSet[:] #由于不能破坏原样本,此处建立一个副本
m, n = np.shape(sortDataSet)
for i in range(m):
for j in range(0, m - i - 1):
if (sortDataSet[j][axis] > sortDataSet[j+1][axis]):
temp = sortDataSet[j]
sortDataSet[j] = sortDataSet[j+1]
sortDataSet[j+1] = temp
#print(sortDataSet)
return sortDataSet
def preOrder(self, node):#对KD树进行先序遍历
if node != None:
print("节点{}".format( node.data))
self.preOrder(node.lchild)
self.preOrder(node.rchild)
def travel(self,node,x ,depth=0): # 递归搜索
if node != None: # 递归终止条件
print("访问{}".format(node.data))
n = len(x) # 特征数
axis = depth % n # 计算轴
if x[axis] < node.data[axis]: # 如果数据小于结点,则往左结点找
self.travel(node.lchild,x, depth + 1)
else:
self.travel(node.rchild,x, depth + 1)
# 以下是递归完毕后,往父结点方向回朔
distNodeAndX = self.dist(x, node.data) # 目标和节点的距离判断
if (self.nearestPoint == None):
print("叶子节点{}确定为初始最近点".format(node.data))# 确定当前点,更新最近的点和最近的值,因为是叶子节点,所以无需回溯
self.nearestPoint = node.data
self.nearestValue = distNodeAndX
print("节点坐标{},深度{}, 最近距离值{:.2f},节点轴的值{},目标点轴的值{}".format(node.data, depth, self.nearestValue,
node.data[axis],
x[axis]))
elif (self.nearestValue > distNodeAndX):
self.nearestPoint = node.data
self.nearestValue = distNodeAndX
print("节点坐标{},深度{}, 最近距离值{:.2f},节点轴的值{},目标点轴的值{}".format(node.data, depth, self.nearestValue,
node.data[axis],
x[axis]))
if x[axis] < node.data[axis]:#对节点的子节点的父节点的另一子节点进行检查
print("回溯——进入node{}的右子节点".format(node.data))
self.travel(node.rchild,x, depth + 1)
else:
print("回溯——进入node{}的左子节点".format(node.data))
self.travel(node.lchild, x,depth + 1)
def search(self,tree,x):
self.nearestPoint=None
self.nearestValue=0
self.travel(tree,x)
return self.nearestPoint, self.nearestValue
def dist(self, x1, x2): #欧式距离的计算
return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5
dataSet = [[2, 3],
[5, 4],
[9, 6],
[4, 7],
[8, 1],
[7, 2]]
x = [2,4.5]
kdtree = KdTree()
print("\n构造kd树\n")
tree = kdtree.create(dataSet, 0)
print("\n先序遍历kd树")
kdtree.preOrder(tree)
print("\nkd搜索")
print(kdtree.search(tree,x))
结果展示:
构造kd树
先序遍历kd树
节点[7, 2]
节点[5, 4]
节点[2, 3]
节点[4, 7]
节点[9, 6]
节点[8, 1]
kd搜索
访问[7, 2]
访问[5, 4]
访问[4, 7]
叶子节点[4, 7]确定为初始最近点
节点坐标[4, 7],深度2, 最近距离值3.20,节点轴的值4,目标点轴的值2
节点坐标[5, 4],深度1, 最近距离值3.04,节点轴的值4,目标点轴的值4.5
回溯——进入node[5, 4]的左子节点
访问[2, 3]
节点坐标[2, 3],深度2, 最近距离值1.50,节点轴的值2,目标点轴的值2
回溯——进入node[2, 3]的左子节点
([2, 3], 1.5)
输入(6.8,0)
构造kd树
先序遍历kd树
节点[7, 2]
节点[5, 4]
节点[2, 3]
节点[4, 7]
节点[9, 6]
节点[8, 1]
kd搜索
访问[7, 2]
访问[5, 4]
访问[2, 3]
叶子节点[2, 3]确定为初始最近点
节点坐标[2, 3],深度2, 最近距离值5.66,节点轴的值2,目标点轴的值6.8
节点坐标[5, 4],深度1, 最近距离值4.39,节点轴的值4,目标点轴的值0
回溯——进入node[5, 4]的右子节点
访问[4, 7]
节点坐标[7, 2],深度0, 最近距离值2.01,节点轴的值7,目标点轴的值6.8
回溯——进入node[7, 2]的右子节点
访问[9, 6]
访问[8, 1]
节点坐标[8, 1],深度2, 最近距离值1.56,节点轴的值8,目标点轴的值6.8
回溯——进入node[8, 1]的右子节点
([8, 1], 1.562049935181331)