统计学习方法 | knn算法 | python实现

K近邻算法没有显示的学习过程,由三个基本要素——距离度量,k值的选择和分类决策规则决定。

距离选择可以是欧式距离,Lp距离或Minkowski距离

k值的选择过大会使模型变得简单;过小会导致过拟合,通常取一个比较小的数值采用交叉验证法来取最优。

分类决策规则往往是多数表决(等价于经验风险最小化)。

kd树最近邻搜索

建树过程比较简单,详见文末代码,这里直接从搜索开始。

在这里插入图片描述
在这里插入图片描述

    # 用kd树的最近邻搜索算法,树已构造好
    def search(self, node, aim, depth=0):
        if node is not None:
            n = len(aim)  # aim是目标点,计算维度
            axis = depth % n
            if aim[axis] < node.data[axis]:
                self.search(node.lchild, aim, depth+1)
            else:
                self.search(node.rchild, aim, depth+1)

            dis = self.dist(aim, node.data)  # 欧式距离
            if self.nearest is None or self.nearestDis > dis:
                self.nearest = node.data   # 保存当前最近点
                self.nearestDis = dis   # 当前最近距离

            # 算法3.3(3)(b) 判断是否需要去另一子结点搜索
            if abs(node.data[axis] - aim[axis]) <= self.nearestDis:
                if aim[axis] < node.data[axis]:
                    self.search(node.rchild, aim, depth + 1)
                else:
                    self.search(node.lchild, aim, depth + 1)

其中nearest和nearestDis在初始化类时定义

    def __init__(self):
        self.kdTree = None
        self.nearest = None
        self.nearestDis = 0

算法的实现学习了这位大佬:k近邻 kd树的python实现

kd树实现k近邻算法

在上述算法的基础上进行改进:用列表来保存k个最近的点,初始化函数变为

    def __init__(self):
        self.kdTree = None
        self.nearest = []
        self.nearestDis = []

算法框架不变,先找到包含目标点的区域然后开始回退。

当列表中的点不足k个时,将当前结点加入nearest列表,同时将当前结点与目标点的距离加入nearestDis列表;另外,当前结点与目标点距离小于nearestDis列表中的最大距离时,对nearest和nearestDis列表进行更新,替换此最远点和对应的距离。

检查是否需要去当前结点的另一子结点搜索,只需将之前的判断半径变为nearest中最远点到目标点的距离就可以。

# 用kd树的k近邻搜索算法,k缺省为1
def search(self, node, aim, k=1, depth=0):
    if node is not None:
        n = len(aim) 
        axis = depth % n
        if aim[axis] < node.data[axis]:
            self.search(node.lchild, aim, k, depth+1)
        else:
            self.search(node.rchild, aim, k, depth+1)

        dis = self.dist(aim, node.data)  # 欧式距离
        if len(self.nearest) < k:  # 若不够k个,加入k近邻列表
            self.nearest.append(node.data)
            self.nearestDis.append(dis)
        elif max(self.nearestDis) > dis:  # 当小于k近邻列表中最大值时替换此最大值
            maxIndex = self.nearestDis.index(max(self.nearestDis))
            self.nearest[maxIndex] = node.data
            self.nearestDis[maxIndex] = dis

        # 判断是否需要去另一子结点搜索
        if abs(node.data[axis] - aim[axis]) <= max(self.nearestDis):  # 搜索半径
            if aim[axis] < node.data[axis]:
                self.search(node.rchild, aim, k, depth + 1)
            else:
                self.search(node.lchild, aim, k, depth + 1)

k缺省为1;当k大于数据集中所有点时,不报错,输出所有点。

对书上实例运行如下:

data = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
x = [5, 3]
kdtree = KdTree()  # 定义一个kd树实例
tree = kdtree.create(data)  # 构造树
# kdtree.preOrder(tree)  
kdtree.search(tree, x, 3)  # 搜索
print(kdtree.nearest)
print(kdtree.nearestDis)
结果:
[[2, 3], [5, 4], [7, 2]]
[3.0, 1.0, 2.23606797749979]

代码下载

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值