统计学习方法第3章 k近邻法 Python实现

3.1 k k k 近邻法介绍

k k k 近邻法是一种基本分类与回归方法. k k k 近邻法的输入为实例的特征向量,对应于特征空间的点;输出为实例的类别,可以取多类. 统计学习方法一书只讨论分类问题中的 k k k 近邻法.

3.1.1 k k k 近邻算法

在这里插入图片描述
k k k 近邻法的特殊情况是 k = 1 k=1 k=1的情形,称为最近邻算法,对于输入的实例点(特征向量) x x x,最近邻法将训练数据集中于 x x x 最邻近点的类作为 x x x 的类.

3.1.2 k k k 近邻模型

k k k 近邻法使用的模型实际上对应于特征空间的划分,模型由三个基本要素决定:距离度量、 k k k 值的选择和分类决策规则,接下来会详细展开.
在这里插入图片描述
在这里插入图片描述
(1) 距离度量

特征空间中两个实例点的距离是两个实例点相似程度的反映, k k k 近邻模型的特征空间一般是 n n n 维实数向量Rn,使用的距离是欧式距离,但也可以是其他距离. L 2 L_2 L2 距离的公式如下: L 2 ( x i , x j ) = ( ∑ l = 1 n ∣ x i ( l ) − x j ( l ) ∣ ) 1 2 L_2(x_i,x_j)=\displaystyle(\sum_{l=1}^{n}\mid x_i^{(l)}-x_j^{(l)}\mid)^\frac{1}{2} L2(xi,xj)=(l=1nxi(l)xj(l))21

(2) k k k 值的选择

k k k 值的选择会对 k k k 近邻法的结果产生重大影响. 如果选择较小的 k k k 值,就相当于用较小的邻域中的训练实例进行预测,"学习"的近似误差会减小,只有与输入实例较近的 (相似的) 训练实例才会对预测结果起作用,但估计误差会增大,预测结果会对近邻的实例点非常敏感;如果选择较大的 k k k 值,则两种误差的结果正相反;如果 k = N k=N k=N,那么无论输入实例是什么,都将简单地预测它属于在训练实例中最多的类.

在应用中, k k k 值一般取一个比较小的数值,通常采用交叉验证法来选取最优的值.

(3) 分类决策规则

k k k 近邻法中的分类决策规则往往是多数表决,即由输入实例的 k k k 个邻近的训练实例中的多数类决定输入实例的类. 在这里插入图片描述

3.1.3 k k k 近邻法的实现: k d kd kd

实现 k k k 近邻法时,主要考虑的问题是如何对训练数据进行快速 k k k 近邻搜索,这点在特征空间的维数大及训练数据容量大时尤其必要. 为了提高 k k k 近邻搜索的效率,可以考虑使用特殊的结构存储训练数据, k d kd kd 树方法就是其中一种.

(1) 构造 k d kd kd

k d kd kd 树是一种对 k k k 维空间中的实例点进行存储以便对其进行快速检索的树形数据结构,它是二叉树,表示对 k k k 维空间的一个划分. 构造 k d kd kd 树相当于不断地用垂直于坐标轴的超平面将 k k k 维空间划分,构成一系列的 k k k 维超矩形区域. 相信直接用算法解释会更加直观,此处列出书中定义的构造平衡 k d kd kd 树算法在这里插入图片描述
(2) 搜索 k d kd kd

所谓搜索 k d kd kd 树就是利用 k d kd kd 树进行 k k k 近邻搜索,这样就可以省去对大部分数据点的搜索,从而减少搜索的计算量,接下来就以最近邻为例加以叙述并列出算法.

给定一个目标点,搜索其最近邻. 首先找到包含目标点的叶结点;然后从该叶结点出发,依次回退到父结点;不断查找与目标点最邻近的结点,当确定不可能存在更近的结点时终止. 这里停止搜索的条件是:如果父结点的另一子结点的超矩形区域与超球体不相交,或不存在比当前最近点最近的点,则停止搜索. 最后列出书中算法:
在这里插入图片描述

3.2 k k k 近邻法 Python 实现

之前说过,文章的重点是放在算法的编程语言实现上,所以接下来是重头戏,话不多说,直接上代码,我尽量多做注释使代码简单易懂.

3.2.1 例3.1–不同 p p p 值下寻找目标点的最近邻点

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np

def compute_distance(point1,point2,p):
    
    #point1与point2的维度一致,p是大于0的整数
    dist = 0
    point_shape = point2.shape
    for i in range(point_shape[0]):
        dist +=(abs(point1[i]-point2[i]))**p
    dist = dist**(1/p)
    return dist

class Solution(object):
    
    def find_nearest(self,vector,point,p):
        
        '''
        vector中的元素与point长度一致
        input:数据集vector,确定距离种类的p值与目标点point
        return:目标点的最近邻点
        '''
        vector.remove(point)
        vector = np.array(vector)
        point = np.array(point)
        distance = compute_distance(vector[0],point,p)
        print("点 %s 与 目标点 %s 之间的Lp距离为:%s" % (str(vector[0]),str(point),str(distance)))
        index = [0]
        for i in range(1,len(vector)):
            tmp = compute_distance(vector[i],point,p)
            print("点 %s 与 目标点 %s 之间的Lp距离为:%s" % (str(vector[i]),str(point),str(tmp)))
            if tmp <= distance:
                distance = tmp
                index.append(i)
        print("目标点 %s 的最近邻点为:%s" % (str(point),str(vector[index[-1]])))
        return distance,vector[index[-1]]

if __name__ == '__main__':
    x = [[1,1],[5,1],[4,4]]
    point = [1,1]
    p = 1
    solution = Solution()
    distance,goalpoint = solution.find_nearest(x,point,p)



p=1时的结果如下:

在这里插入图片描述

3.2.2 例3.2–构造平衡 k d kd kd

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np

class QuickSort(object):
    # 首先在构造kd树的时需要寻找中位数,因此用快速排序来获取一个list中的中位数
    def __init__(self, low, high, array):
        self._array = array
        self._low = low
        self._high = high
        self._medium = (low+high+1)//2 # python3中的整除

    def get_medium_num(self):
        return self.quick_sort_for_medium(self._low, self._high, 
                                          self._medium, self._array)

    def quick_sort_for_medium(self, low, high, medium, array): #用快速排序来获取中位数
        if high == low:
            return array[low] # find medium
        if high > low:
            index, partition = self.sort_partition(low, high, array); 
            #print array[low:index], partition, array[index+1:high+1]
            if index == medium:
                return partition
            if index > medium:
                return self.quick_sort_for_medium(low, index-1, medium, array)
            else:
                return self.quick_sort_for_medium(index+1, high, medium, array)

    def quick_sort(self, low, high, array):  #正常的快排
        if high > low:
            index, partition = self.sort_partition(low, high, array); 
            #print array[low:index], partition, array[index+1:high+1]
            self.quick_sort(low, index-1, array)
            self.quick_sort(index+1, high, array)

    def sort_partition(self, low, high, array): # 用第一个数将数组里面的数分成两部分
        index_i = low
        index_j = high
        partition = array[low]
        while index_i < index_j:
            while (index_i < index_j) and (array[index_j] >= partition):
                index_j -= 1
            if index_i < index_j:
                array[index_i] = array[index_j]
                index_i += 1
            while (index_i < index_j) and (array[index_i] < partition):
                index_i += 1
            if index_i < index_j:
                array[index_j] = array[index_i]
                index_j -= 1
        array[index_i] = partition
        return index_i, partition


class KDTree(object):

    def __init__(self, input_x):
        self._input_x = np.array(input_x)
        (data_num, axes_num) = np.shape(self._input_x)
        self._data_num = data_num
        self._axes_num = axes_num
        self._nearest = None  #用来存储最近的节点
        return

    def construct_kd_tree(self):
        return self._construct_kd_tree(0, 0, self._input_x)

    def _construct_kd_tree(self, depth, axes, data):
        if not data.any():
            return None
        axes_data = data[:, axes].copy()
        qs = QuickSort(0, axes_data.shape[0]-1, axes_data)
        medium = qs.get_medium_num() #找到轴的中位数

        data_list = []
        left_data = []
        right_data = []
        data_range = range(np.shape(data)[0])
        for i in data_range:   # 跟中位数相比较
            if data[i][axes] == medium:  #相等
                data_list.append(data[i])
            elif data[i][axes] < medium: 
                left_data.append(data[i])
            else:
                right_data.append(data[i])

        left_data = np.array(left_data)
        right_data = np.array(right_data)
        left = self._construct_kd_tree(depth+1, (axes+1)% self._axes_num, left_data)
        right = self._construct_kd_tree(depth+1, (axes+1)% self._axes_num, right_data)
        #[树的深度,轴,中位数,该节点的数据,左子树,右子树]
        root = [depth, axes, medium, data_list, left, right] 
        return root

    def print_kd_tree(self, root): #打印kd树
        if root:
            [depth, axes, medium, data_list, left, right] = root
            print('{} {}'.format('    '*depth, data_list[0]))
            if root[4]:
                self.print_kd_tree(root[4])
            if root[5]:
                self.print_kd_tree(root[5])


if __name__ == '__main__':
    input_x = [[2,3], [5,4], [9,6], [4,7], [8,1], [7,2]]
    kd = KDTree(input_x)
    tree = kd.construct_kd_tree()
    kd.print_kd_tree(tree)



得到的运行结果如下, k d kd kd 树的层次也很清晰.

在这里插入图片描述

3.2.3 例3.3–利用搜索 k d kd kd 树求目标点的最近邻

因为例题将点抽象化了,故该例使用了例3.2的数据,并使用了sklearn包中带的KDTree函数.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import Circle
from sklearn.neighbors import KDTree
np.random.seed(0)
# 随机产生150个二维数据
points = np.array([[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]])
tree = KDTree(points)
point = np.array([4,5])
# k近邻法搜索,k=1时为最近邻
dists, indices = tree.query([point], k=1)

# q指定半径搜索
r = 2
indices = tree.query_radius([point], r)

fig = plt.figure()
ax = fig.add_subplot(111, aspect='equal')
ax.add_patch(Circle(point, r, color='g', fill=False))
X, Y = [p[0] for p in points], [p[1] for p in points]
plt.scatter(X, Y)
plt.scatter([point[0]], [point[1]], c='r')
plt.show()


搜索结果如下图所示:

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值