《统计学习方法》第三章——k近邻法及Python实现

一、概述

本文是《统计学习方法》的第三章,包含k近邻算法的原理与python实现。希望自己能坚持下去,完成整本书的学习

二、k近邻算法

k近邻是一种基本的分类与回归方法。本文只讨论分类问题中的k近邻算法。k近邻算法的输入为实例的特征向量,对于输入的实例,可以取多类。分类时,对新的实例,根据其k个最近邻的训练实例的类别,通过多数表决等方式进行预测。因此,K近邻法不具有显示的学习过程。k近邻实际上是利用训练集对特征空间进行划分,并作为其分类的“模型”。k值的选择、距离度量及分类决策规则是k近邻法的三个基本要素。

三、k值的选择与分析

k值的选择会对k近邻算法的结果产生重大影响。
k较小,容易被噪声影响,发生过拟合。结果受临近的几个点的影响会很大,估计误差会增大。
k较大,学习的近似误差会增大,与输入实例距离较远的实例也会对预测起作用,使预测发生错误。k较大相当于模型变得简单。

四、k近邻法的实现:KD树

k近邻算法最简单的实现方法是线性扫描,这时要计算输入实例与每一个训练实例的距离。当训练集很大的时候,计算非常耗时,这种方法是不可行的。
为了提高k近邻算法搜索的效率,可以考虑使用特殊的结构存储训练数据,以减少计算距离的次数。下面介绍这些方法中的一种,kd树。
对数据集T中的子集S初始化S=T,取当前节点node=root取维数的序数i=0,对S递归执行:
找出S的第i维的中位数对应的点,通过该点,且垂直于第i维坐标轴做一个超平面。该点加入node的子节点。该超平面将空间分为两个部分,对这两个部分分别重复此操作(S=S’,++i,node=current),直到不可再分。
下面是python代码实现:

T = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
class node:
    def __init__(self,point,split):
        self.left=None
        self.right=None
        self.parent=None
        self.point=point
        self.split=split
        pass
    def set_left(self,node):
        if node==None:
            pass
        self.left=node
        node.parent=self
    def set_right(self,node):
        if node==None:
            pass
        self.right=node
        node.parent=self
def median(data):
    m=len(data)//2
    return data[m],m
def build_kdtree(data, d):
    data = sorted(data, key=lambda x: x[d])
    p, m = median(data)
    tree = node(p, d)

    del data[m]

    if m > 0: tree.set_left(build_kdtree(data[:m], not d))
    if len(data) > 1: tree.set_right(build_kdtree(data[m:], not d))
    return tree
def distance(a, b):
    print (a, b)
    return ((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2) ** 0.5
def search_kdtree(tree, d, target, root):
    if target[d] < tree.point[d]:
        if tree.left != None:
            return search_kdtree(tree.left, not d, target, root)
    else:
        if tree.right != None:
            return search_kdtree(tree.right, not d, target, root)

    def update_best(t, best):
        if t == None: return
        t = t.point
        d = distance(t, target)
        if d < best[1]:
            best[1] = d
            best[0] = t
        return
    best = [tree.point, distance(tree.point, target)]
    while (tree.parent != None and tree != root):
        split = tree.parent.split
        if(best[1] > abs(target[split] - tree.parent.point[split])):
            update_best(tree.parent, best)
            tempBest = None
            if(tree.point[split] < tree.parent.point[split]):
                if(tree.parent.right != None):
                    tempBest = search_kdtree(tree.parent.right, tree.parent.right.split, target, tree.parent.right)
            else:
                if(tree.parent.left != None):
                    tempBest = search_kdtree(tree.parent.left, tree.parent.left.split, target, tree.parent.left)
            if(tempBest != None and tempBest[1] < best[1]):
                best = tempBest
        tree = tree.parent
    return best


kd_tree = build_kdtree(T, 0)
print (search_kdtree(kd_tree, 0, [9, 4], kd_tree))

搜索是一个递归的过程。先直接到叶节点,然后找到目标点的插入位置,然后往上走,逐步用自己到目标点的距离画个超球体,用超球体圈住的点来更新最近邻(或k最近邻)。
输出结果如下:

[8, 1] [9, 4]
[9, 6] [9, 4]
[[9, 6], 2.0]

图中仅用了两次搜索,便查出了距离最近的点,因此可以看出kd树是一个性能优越的数据结构。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值