统计学习方法-k近邻法-python代码实现

knn

本文主要结合书中例3-2的数据,构建kd树,并预测点[6.8,0],[2,4.5]的最邻近点。

代码实现:

import numpy as np
class Node:
    def __init__(self, data, lchild = None, rchild = None):
        self.data = data
        self.lchild = lchild
        self.rchild = rchild


class KdTree:
    def __init__(self):
        self.kdTree = None

    def create(self, dataSet, depth):   #创建kd树,返回根结点
        if (len(dataSet) > 0):
            m, n = np.shape(dataSet)    #求出样本行,列
            midIndex = int(m / 2) #中间数的索引位置
            axis = depth % n    #判断以哪个轴划分数据
            sortedDataSet = self.sort(dataSet, axis) #进行排序
            node = Node(sortedDataSet[midIndex]) #将节点数据域设置为中位数,具体参考下书本
            # print sortedDataSet[midIndex]
            leftDataSet = sortedDataSet[: midIndex] #将中位数的左边创建2改副本
            rightDataSet = sortedDataSet[midIndex+1 :]
            node.lchild = self.create(leftDataSet, depth+1) #将中位数左边样本传入来递归创建树
            node.rchild = self.create(rightDataSet, depth+1)
            return node
        else:
            return None

    def sort(self, dataSet, axis):  #采用冒泡排序,利用aixs作为轴进行划分
        sortDataSet = dataSet[:]    #由于不能破坏原样本,此处建立一个副本
        m, n = np.shape(sortDataSet)
        for i in range(m):
            for j in range(0, m - i - 1):
                if (sortDataSet[j][axis] > sortDataSet[j+1][axis]):
                    temp = sortDataSet[j]
                    sortDataSet[j] = sortDataSet[j+1]
                    sortDataSet[j+1] = temp
        #print(sortDataSet)
        return sortDataSet

    def preOrder(self, node):#对KD树进行先序遍历
        if node != None:
            print("节点{}".format( node.data))
            self.preOrder(node.lchild)
            self.preOrder(node.rchild)

    def travel(self,node,x ,depth=0):  # 递归搜索
        if node != None:  # 递归终止条件
            print("访问{}".format(node.data))
            n = len(x)  # 特征数
            axis = depth % n  # 计算轴
            if x[axis] < node.data[axis]:  # 如果数据小于结点,则往左结点找
                self.travel(node.lchild,x, depth + 1)
            else:
                self.travel(node.rchild,x, depth + 1)
                
            # 以下是递归完毕后,往父结点方向回朔
            distNodeAndX = self.dist(x, node.data)  # 目标和节点的距离判断
            if (self.nearestPoint == None):
                print("叶子节点{}确定为初始最近点".format(node.data))# 确定当前点,更新最近的点和最近的值,因为是叶子节点,所以无需回溯
                self.nearestPoint = node.data
                self.nearestValue = distNodeAndX
                print("节点坐标{},深度{}, 最近距离值{:.2f},节点轴的值{},目标点轴的值{}".format(node.data, depth, self.nearestValue,
                                                                         node.data[axis],
                                                                         x[axis]))
            elif (self.nearestValue > distNodeAndX):
                self.nearestPoint = node.data
                self.nearestValue = distNodeAndX
                print("节点坐标{},深度{}, 最近距离值{:.2f},节点轴的值{},目标点轴的值{}".format(node.data, depth, self.nearestValue,
                                                                         node.data[axis],
                                                                         x[axis]))
                if x[axis] < node.data[axis]:#对节点的子节点的父节点的另一子节点进行检查
                    print("回溯——进入node{}的右子节点".format(node.data))
                    self.travel(node.rchild,x, depth + 1)
                else:
                    print("回溯——进入node{}的左子节点".format(node.data))
                    self.travel(node.lchild, x,depth + 1)



    def search(self,tree,x):
        self.nearestPoint=None
        self.nearestValue=0
        self.travel(tree,x)
        return self.nearestPoint, self.nearestValue

    def dist(self, x1, x2): #欧式距离的计算
        return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5

dataSet = [[2, 3],
           [5, 4],
           [9, 6],
           [4, 7],
           [8, 1],
           [7, 2]]
x = [2,4.5]
kdtree = KdTree()
print("\n构造kd树\n")
tree = kdtree.create(dataSet, 0)
print("\n先序遍历kd树")
kdtree.preOrder(tree)
print("\nkd搜索")
print(kdtree.search(tree,x))

结果展示:

构造kd树

先序遍历kd树

节点[7, 2]
节点[5, 4]
节点[2, 3]
节点[4, 7]
节点[9, 6]
节点[8, 1]

kd搜索

访问[7, 2]
访问[5, 4]
访问[4, 7]
叶子节点[4, 7]确定为初始最近点
节点坐标[4, 7],深度2, 最近距离值3.20,节点轴的值4,目标点轴的值2
节点坐标[5, 4],深度1, 最近距离值3.04,节点轴的值4,目标点轴的值4.5
回溯——进入node[5, 4]的左子节点
访问[2, 3]
节点坐标[2, 3],深度2, 最近距离值1.50,节点轴的值2,目标点轴的值2
回溯——进入node[2, 3]的左子节点
([2, 3], 1.5)

输入(6.8,0)

构造kd树
先序遍历kd树
节点[7, 2]
节点[5, 4]
节点[2, 3]
节点[4, 7]
节点[9, 6]
节点[8, 1]

kd搜索
访问[7, 2]
访问[5, 4]
访问[2, 3]
叶子节点[2, 3]确定为初始最近点
节点坐标[2, 3],深度2, 最近距离值5.66,节点轴的值2,目标点轴的值6.8
节点坐标[5, 4],深度1, 最近距离值4.39,节点轴的值4,目标点轴的值0
回溯——进入node[5, 4]的右子节点
访问[4, 7]
节点坐标[7, 2],深度0, 最近距离值2.01,节点轴的值7,目标点轴的值6.8
回溯——进入node[7, 2]的右子节点
访问[9, 6]
访问[8, 1]
节点坐标[8, 1],深度2, 最近距离值1.56,节点轴的值8,目标点轴的值6.8
回溯——进入node[8, 1]的右子节点
([8, 1], 1.562049935181331)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值