(笔记+代码+习题)统计学习方法第三章 KNN及其kd树实现(Python)

一、KNN总体思想

给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最邻近的 k k k个实例,这 k k k个实例的多数属于某个类,就把该输入实例分为这个类。

  • 解决问题:分类问题
  • 输入:实例的特征向量
  • 输出:实例的类别
  • 三个基本要素 k k k值的选择、距离度量、分类决策规则
  • 最近邻算法 k = 1 k=1 k=1

二、KNN详解

1、算法

输入: T = { ( x 1 , y 1 ) , ( x 2 , y 2 ) , . . . , ( x N , y N ) } T=\{ (x_1,y_1),(x_2,y_2),...,(x_N,y_N)\} T={(x1,y1),(x2,y2),...,(xN,yN)}

输出: y y y

步骤:

  1. 根据指定的距离度量,在训练集 T T T中找出与 x x x最近邻的 k k k个点,涵盖这 k k k个点的 x x x的邻域记作 N k ( x ) N_k(x) Nk(x)

  2. N k ( x ) N_k(x) Nk(x)中根据**分类决策规则(如多数表决)**决定 x x x的类别 y y y

    y = arg ⁡ max ⁡ c j ∑ x i ∈ N k ( x ) I ( y i = c j ) , i = 1 , 2 , … , N , j = 1 , 2 , … , K y=\arg\max_{c_j}\sum_{x_i\in N_k(x)}I(y_i=c_j), i=1,2,\dots,N, j=1,2,\dots,K y=argcjmaxxiNk(x)I(yi=cj),i=1,2,,N,j=1,2,,K

2、距离度量

闵可夫斯基距离:

L p ( x i , x j ) = ( ∑ l = 1 n ∣ x i ( l ) − x j ( l ) ∣ p ) 1 p L_p(x_i, x_j)=\left(\sum_{l=1}^{n}{\left|x_{i}^{(l)}-x_{j}^{(l)}\right|^p}\right)^{\frac{1}{p}} Lp(xi,xj)=(l=1nxi(l)xj(l)p)p1

  • p = 1 p=1 p=1时,曼哈顿距离
  • p = 2 p=2 p=2时,欧式距离
  • p = ∞ p=∞ p=时,切比雪夫距离

注:

范数是对向量或者矩阵的度量,是一个标量,这个里面两个点之间的 L p L_p Lp距离可以认为是两个点坐标差值的 p p p范数。

3、K值选择

  • k k k值较小时,近似误差(approximation error)会减小,只有与输入实例较近(相似的)训练实例才会对预测结果起作用。但缺点时“学习”的估计误差(estimation error)会增大,对近邻实例点非常敏感,如果是噪声,预测就会出错。 k k k值减小,模型越复杂,越可能发生过拟合。
  • k k k值较大时,与前面相反,模型变得简单。
  • 实际应用时, k k k值一般取一个比较小的数值,采用交叉验证法来选取最优的 k k k值。

三、KNN的kd树实现及代码(Python)

KNN主要考虑的问题是如何对训练数据进行快速 k k k近邻搜索:

  • 线性扫描(linear scan)
  • kd tree
  • 其它

注:

kd树是存储 k k k维空间数据的树结构,这里的 k k k k k k近邻法的意义不同。(输入 x x x k k k维)

1、构造KD树

下面是构建最近邻算法,构建k近邻见后面习题3.3.

  • 输入:

    k k k维空间数据集 T = { x 1 , x 2 , . . . , x N } T=\{ x_1,x_2,...,x_N\} T={x1,x2,...,xN},其中 x i = ( x 1 ( 1 ) , ( x 2 ( 2 ) , . . . , ( x k ( k ) ) T , i = 1 , 2 , . . . , N x_i=(x^{(1)}_1,(x^{(2)}_2,...,(x^{(k)}_k)^T,i=1,2,...,N xi=(x1(1),(x2(2),...,(xk(k))T,i=1,2,...,N

  • 输出:kd树

  • 过程:

    • stpe1

      构造根节点,根节点对应于包含 T T T k k k维空间的超矩形区域。

      选择 x ( 1 ) x^{(1)} x(1)为坐标轴,以 T T T中所有实例的 x ( 1 ) x^{(1)} x(1)坐标的中位数为切分点。由根节点生成深度为1的左右子节点,左子节点对应于坐标 x ( 1 ) x^{(1)} x(1)小于切分点区域,右子节点对应于坐标 x ( 1 ) x^{(1)} x(1)大于切分点的子区域

    • step2

      重复:对深度为 j j j的结点,选择 x ( l ) x^{(l)} x(l)为切分点的坐标轴, l = j ( m o d   k ) + 1 l=j(mod\ k)+1 l=j(mod k)+1,以该结点的区域中所有实例的 x ( l ) x^{(l)} x(l)坐标的中位数为切分点,将该节点对应的超矩形区域划分为两个子区域。

    • step3

      直到两个子区域没有实例存在时停止

  • 代码实现:

    # kd-tree每个结点中主要包含的数据结构如下
    class KdNode(object):
        def __init__(self, dom_elt, split, left, right):
            self.dom_elt = dom_elt  # k维向量节点(k维空间中的一个样本点)
            self.split = split  # 整数(进行分割维度的序号)
            self.left = left  # 该结点分割超平面左子空间构成的kd-tree
            self.right = right  # 该结点分割超平面右子空间构成的kd-tree
    
    class KdTree(object):
        def __init__(self, data):
            k = len(data[0])  # 数据维度
    
            def CreateNode(split, data_set):  # 按第split维划分数据集exset创建KdNode
                if not data_set:  # 数据集为空
                    return None
                # key参数的值为一个函数,此函数只有一个参数且返回一个值用来进行比较
                # operator模块提供的itemgetter函数用于获取对象的哪些维的数据,参数为需要获取的数据在对象中的序号
                #data_set.sort(key=itemgetter(split)) # 按要进行分割的那一维数据排序
                data_set.sort(key=lambda x: x[split])
                split_pos = len(data_set) // 2  # //为Python中的整数除法
                median = data_set[split_pos]  # 中位数分割点
                split_next = (split + 1) % k  # cycle coordinates
    
                # 递归的创建kd树
                return KdNode(
                    median,
                    split,
                    CreateNode(split_next, data_set[:split_pos]),  # 创建左子树
                    CreateNode(split_next, data_set[split_pos + 1:]))  # 创建右子树
    
            self.root = CreateNode(0, data)  # 从第0维分量开始构建kd树,返回根节点
    
    # KDTree的前序遍历
    def preorder(root):
        print(root.dom_elt)
        if root.left:  # 节点不为空
            preorder(root.left)
        if root.right:
            preorder(root.right)
    
  • 测试代码:

    data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
    kd = KdTree(data)
    preorder(kd.root)
    

    输出结果:

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3D5LkZnb-1636978032271)(https://s3-us-west-2.amazonaws.com/secure.notion-static.com/69518ba5-3e95-4397-86da-aa2f4c251a1e/Untitled.png)]

2、搜索KD树

  • 输入:已构造的kd树,目标点 x x x

  • 输出: x x x的最近邻

  • 过程

    • step1——在 k d kd kd树中找出包含目标点 x x x的叶结点

      从根结点出发,递归地向下访问 k d kd kd树。若目标点 x x x当前维的坐标小于切分点的坐标,则移动到左子节点,否则移动到右子节点。直到子节点为叶结点为止。

    • step2——以此节点为“当前最近点”

    • step3——递归地向上回退,在每个结点进行一下操作

      • a、如果该结点保存的实例点比当前最近点距离目标点更近,则以该实例点为“当前最近点”

      • b、当前最近点一定存在于该节点一个子节点对应的区域。检查该子节点的父节点的另一子节点对应的区域是否有更近的点。具体地,检查另一子节点对应地区域是否于以目标点为球心、以目标点与“当前最近点”间地距离为半径的超球体相交。

        如果相交,可能在另一个子节点对应的区域内存在距目标点更近的点,移动到另一个子节点。接着,递归地进行最近邻搜索。

        如果不相交,向上回退

    • step4——当回退到根节点时,搜索结束。最后的“当前最近点”即为 x x x的最近邻点。

    如果实例点是随机分布的, k d kd kd树搜索的平均计算复杂度是 O ( l o g N ) O(log N) O(logN)。当空间维数接近训练实例数时,它的效率会迅速下降。

    • 代码实现

      # 对构建好的kd树进行搜索,寻找与目标点最近的样本点:
      from math import sqrt
      from collections import namedtuple
      
      # 定义一个namedtuple,分别存放最近坐标点、最近距离和访问过的节点数
      result = namedtuple("Result_tuple",
                          "nearest_point  nearest_dist  nodes_visited")
      
      def find_nearest(tree, point):
          k = len(point)  # 数据维度
      
          def travel(kd_node, target, max_dist):
              if kd_node is None:
                  return result([0] * k, float("inf"),
                                0)  # python中用float("inf")和float("-inf")表示正负无穷
      
              nodes_visited = 1
      
              s = kd_node.split  # 进行分割的维度
              pivot = kd_node.dom_elt  # 进行分割的“轴”
      
              if target[s] <= pivot[s]:  # 如果目标点第s维小于分割轴的对应值(目标离左子树更近)
                  nearer_node = kd_node.left  # 下一个访问节点为左子树根节点
                  further_node = kd_node.right  # 同时记录下右子树
              else:  # 目标离右子树更近
                  nearer_node = kd_node.right  # 下一个访问节点为右子树根节点
                  further_node = kd_node.left
      
              temp1 = travel(nearer_node, target, max_dist)  # 进行遍历找到包含目标点的区域
      
              nearest = temp1.nearest_point  # 以此叶结点作为“当前最近点”
              dist = temp1.nearest_dist  # 更新最近距离
      
              nodes_visited += temp1.nodes_visited
      
              if dist < max_dist:
                  max_dist = dist  # 最近点将在以目标点为球心,max_dist为半径的超球体内
      
              temp_dist = abs(pivot[s] - target[s])  # 第s维上目标点与分割超平面的距离
              if max_dist < temp_dist:  # 判断超球体是否与超平面相交
                  return result(nearest, dist, nodes_visited)  # 不相交则可以直接返回,不用继续判断
      
              #----------------------------------------------------------------------
              # 计算目标点与分割点的欧氏距离
              temp_dist = sqrt(sum((p1 - p2)**2 for p1, p2 in zip(pivot, target)))
      
              if temp_dist < dist:  # 如果“更近”
                  nearest = pivot  # 更新最近点
                  dist = temp_dist  # 更新最近距离
                  max_dist = dist  # 更新超球体半径
      
              # 检查另一个子结点对应的区域是否有更近的点
              temp2 = travel(further_node, target, max_dist)
      
              nodes_visited += temp2.nodes_visited
              if temp2.nearest_dist < dist:  # 如果另一个子结点内存在更近距离
                  nearest = temp2.nearest_point  # 更新最近点
                  dist = temp2.nearest_dist  # 更新最近距离
      
              return result(nearest, dist, nodes_visited)
      
          return travel(tree.root, point, float("inf"))  # 从根节点开始递归
      
    • 测试代码:

      from time import clock
      from random import random
      
      # 产生一个k维随机向量,每维分量值在0~1之间
      def random_point(k):
          return [random() for _ in range(k)]
       
      # 产生n个k维随机向量 
      def random_points(k, n):
          return [random_point(k) for _ in range(n)]
      
      ret = find_nearest(kd, [3,4.5])
      print (ret)
      

      结果:

      [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-h607Zces-1636978032277)(https://s3-us-west-2.amazonaws.com/secure.notion-static.com/0f6c1d3a-0dcb-42f7-931c-cc6515a419cb/Untitled.png)]

四、习题

习题3.1

参照图3.1,在二维空间中给出实例点,画出 k k k为1和2时的 k k k近邻法构成的空间划分,并对其进行比较,体会 k k k值选择与模型复杂度及预测准确率的关系。

%matplotlib inline
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

data = np.array([[5, 12, 1], [6, 21, 0], [14, 5, 0], [16, 10, 0], [13, 19, 0],
                 [13, 32, 1], [17, 27, 1], [18, 24, 1], [20, 20,
                                                         0], [23, 14, 1],
                 [23, 25, 1], [23, 31, 1], [26, 8, 0], [30, 17, 1],
                 [30, 26, 1], [34, 8, 0], [34, 19, 1], [37, 28, 1]])
X_train = data[:, 0:2]
y_train = data[:, 2]

models = (KNeighborsClassifier(n_neighbors=1, n_jobs=-1),
          KNeighborsClassifier(n_neighbors=2, n_jobs=-1))
models = (clf.fit(X_train, y_train) for clf in models)

titles = ('K Neighbors with k=1', 'K Neighbors with k=2')

fig = plt.figure(figsize=(15, 5))
plt.subplots_adjust(wspace=0.4, hspace=0.4)

X0, X1 = X_train[:, 0], X_train[:, 1]

x_min, x_max = X0.min() - 1, X0.max() + 1
y_min, y_max = X1.min() - 1, X1.max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.2),
                     np.arange(y_min, y_max, 0.2))

for clf, title, ax in zip(models, titles, fig.subplots(1, 2).flatten()):
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    colors = ('red', 'green', 'lightgreen', 'gray', 'cyan')
    cmap = ListedColormap(colors[:len(np.unique(Z))])
    ax.contourf(xx, yy, Z, cmap=cmap, alpha=0.5)
    ax.scatter(X0, X1, c=y_train, s=50, edgecolors='k', cmap=cmap, alpha=0.5)
    ax.set_title(title)

plt.show()

结果:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-B92bciBF-1636978032280)(https://s3-us-west-2.amazonaws.com/secure.notion-static.com/4b6cd2bb-49fb-47b9-84b2-5d737229350b/Untitled.png)]

习题3.3(重要!!!)

参照算法3.3,写出输出为 x x x k k k近邻的算法。

算法:用kd树的k近邻搜索
输入:已构造的kd树;目标点x;
输出:x的最近邻

  1. 在kd树中找出包含目标点x的叶结点:从根结点出发,递归地向下访问树。若目标点x当前维的坐标小于切分点的坐标,则移动到左子结点,否则移动到右子结点,直到子结点为叶结点为止;
  2. 如果“当前k近邻点集”元素数量小于k或者叶节点距离小于“当前k近邻点集”中最远点距离,那么将叶节点插入“当前k近邻点集”;
  3. 递归地向上回退,在每个结点进行以下操作:
    (a)如果“当前k近邻点集”元素数量小于k或者当前节点距离小于“当前k近邻点集”中最远点距离,那么将该节点插入“当前k近邻点集”。
    (b)检查另一子结点对应的区域是否与以目标点为球心、以目标点与于“当前k近邻点集”中最远点间的距离为半径的超球体相交。如果相交,可能在另一个子结点对应的区域内存在距目标点更近的点,移动到另一个子结点,接着,递归地进行最近邻搜索;如果不相交,向上回退; 4. 当回退到根结点时,搜索结束,最后的“当前k近邻点集”即为x的最近邻点。

代码实现:

# 构建kd树,搜索待预测点所属区域
from collections import namedtuple
import numpy as np

# 建立节点类
class Node(namedtuple("Node", "location left_child right_child")):
    def __repr__(self):
        return str(tuple(self))

# kd tree类
class KdTree():
    def __init__(self, k=1):
        self.k = k
        self.kdtree = None

    # 构建kd tree
    def _fit(self, X, depth=0):
        try:
            k = self.k
        except IndexError as e:
            return None
        # 这里可以展开,通过方差选择axis
        axis = depth % k
        X = X[X[:, axis].argsort()]
        median = X.shape[0] // 2
        try:
            X[median]
        except IndexError:
            return None
        return Node(location=X[median],
                    left_child=self._fit(X[:median], depth + 1),
                    right_child=self._fit(X[median + 1:], depth + 1))

    def _search(self, point, tree=None, depth=0, best=None):
        if tree is None:
            return best
        k = self.k
        # 更新 branch
        if point[0][depth % k] < tree.location[depth % k]:
            next_branch = tree.left_child
        else:
            next_branch = tree.right_child
        if not next_branch is None:
            best = next_branch.location
        return self._search(point,
                            tree=next_branch,
                            depth=depth + 1,
                            best=best)

    def fit(self, X):
        self.kdtree = self._fit(X)
        return self.kdtree

    def predict(self, X):
        res = self._search(X, self.kdtree)
        return res

KNN = KdTree()
X_train = np.array([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]])
KNN.fit(X_train)
X_new = np.array([[3, 4.5]])
res = KNN.predict(X_new)

x1 = res[0]
x2 = res[1]

print("x点的最近邻点是({0}, {1})".format(x1, x2))

参考代码:

fengdu78/lihang-code

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值