【统计学习方法】K-近邻法

一、前言

K-近邻算法是一种基本的用于分类和回归的非参数统计方法,本篇blog将对分类问题中的k-近邻算法进行总结以及在文末给出了简单的Python实现。see more details in KNN.

二、K-近邻算法

K-近邻算法应用于分类问题时,算法具体内容很简单。首先需要注意的是:K-近邻算法是一种非参数统计方法
顾名思义,K-近邻算法不具备有显式的学习过程,当对新的样本进行分类时,K-近邻算法将计算该样本与训练集中所有样本的“距离”,并根据其k个最近邻的训练实例的类别,通过多数表决的方法来对新样本的类别进行预测:
y = arg ⁡ max ⁡ c j ∑ x i ∈ N k ( x ) I ( y i = c j ) , i = 1 , 2 , . . . , N , j = 1 , 2 , . . . , K {\rm{y}} = \mathop {\arg \max }\limits_{{c_j}} \sum\limits_{{x_i} \in {N_k}(x)} {I({y_i} = {c_j}),i = 1,2,...,N,j = 1,2,...,K} y=cjargmaxxiNk(x)I(yi=cj),i=1,2,...,N,j=1,2,...,K
其中, c j {c_j} cj为实例类别集合中类别, I I I为指示函数, N k ( x ) {N_k}(x) Nk(x)为在训练集中与样本最近邻的k个样本的集合, N N N为训练集中的样本总数。
K-近邻算法的三个基本要素是:k值的选择+距离度量+分类决策规则。分类决策规则采用多数表决的方法,距离度量有如欧氏距离、曼哈顿距离等方式,这里不再赘述,K-近邻算法的关键在于k值的选择,当k减少时,模型变得更加复杂,当k增加时,模型变得简单,实际应用时通常采用交叉验证法来对k值进行选择。

对给定的实例 x ∈ X x\in X xX,其最近邻的k个训练实例点构成集合 N k ( x ) {N_k}(x) Nk(x),如果涵盖 N k {N_k} Nk的区域类别是 c j {c_j} cj,那么误分类率是: 1 k ∑ x i ∈ N K ( x ) I ( y i ≠ c j ) = 1 − 1 k ∑ x i ∈ N K ( x ) I ( y i = c j ) \frac{1}{k}\sum\limits_{{x_i} \in {N_K}(x)} {I({y_i} \ne {c_j})} = 1 - \frac{1}{k}\sum\limits_{{x_i} \in {N_K}(x)} {I({y_i} = {c_j})} k1xiNK(x)I(yi=cj)=1k1xiNK(x)I(yi=cj)
要使得误分类率最小即经验风险最小,就要使得 ∑ x i ∈ N K ( x ) I ( y i = c j ) \sum\limits_{{x_i} \in {N_K}(x)} {I({y_i} = {c_j})} xiNK(x)I(yi=cj)最大,所以多数表决规则等价于经验风险最小化。

三、简单的Python实现

3.1 数据准备

from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd

data, labels = make_blobs(n_samples=200, n_features=3, centers=2, random_state=42)
x_train, x_test, y_train, y_test = train_test_split(data, labels, train_size=0.9, random_state=42)

3.2 具体代码

k = 4

def distance(x, y):
	#欧氏距离
    return np.sqrt(np.sum(np.square((x-y).T), axis=0))

def knn(train_data, train_label, test_data, k):
    #knn具体实现
    pred_labels = []
    
    for i in test_data:
        temp = {}
        
        distance_list = distance(i, train_data)
        k_labels = train_label[np.argsort(distance_list)][:k]
        
        for i in range(k):
            if k_labels[i] in temp:
                temp[k_labels[i]] += 1
            else:
                temp[k_labels[i]] = 1
                
        pred_labels.append(max(temp, key=lambda x : temp[x]))
    
    return pred_labels

def acc(y_pred, y_true):
	#计算准确率
    print(sum(y_pred == y_true) / len(y_true))

四、K-近邻法的实现:kd树

上述给出了K-近邻法最简单的实现方法,即线性扫描,但是当训练集很大时,线性扫描的方法计算十分耗时。为了提高K-近邻搜索的效率,考虑使用特殊的数据结构存储训练数据,以减少计算距离的次数,比如下述的kd树。

4.1 kd树的构建

输入 D = { x 1 , x 2 , . . . , x N } D=\{{x_1},{x_2},...,{x_N}\} D={x1,x2,...,xN},其中 x i = ( x i ( 1 ) , x i ( 2 ) , x i ( 3 ) . . . , x i ( k ) ) T {x_i}=({\rm{x}}_i^{(1)}, {\rm{x}}_i^{(2)}, {\rm{x}}_i^{(3)}...,{\rm{x}}_i^{(k)})^T xi=(xi(1),xi(2),xi(3)...,xi(k))T i = 1 , 2 , . . . , N i=1,2,...,N i=1,2,...,N
输出:kd树

  1. 构造根节点,根节点对应于包含D的k维空间的超矩形区域:选择 x ( 1 ) {x^{(1)}} x(1)所对应的维度为坐标轴,利用D中所有实例的 x ( 1 ) {x^{(1)}} x(1)维度对应的坐标的中位数作为切分点。
  2. 对深度为j 的结点,选择 x ( l ) {x^{(l)}} x(l)作为切分的坐标轴,其中 l = j ( m o d ) k + 1 l=j(mod)k+1 l=j(mod)k+1,剩下的步骤与第一步相同。
  3. 重复上述递归,直到实例被划分完全,这样就完成了kd树的构建。

4.1.1 Python实现

KD树构建:

class KD_Node:
    def __init__(self, data, depth):
        self.data = data
        self.depth = depth
        self.right = None
        self.left = None
        
class KD_Tree:
    def __init__(self, data):
        self.data = data
        self.root = None
        
    def _build(self, points, depth):
        
        # 1、递归退出条件:直到两个子区域没有实例存在时停止,从而形成对kd树的区域划分
        if len(points) == 0:
            return None
        
        # 2、获取单个样本维度
        k = len(points[0])
        
        # 3、选择切分轴
        _axis = depth % k
        
        # 4、按样本_axis轴进行排序
        points.sort(key=lambda x : x[_axis])
        
        # 5、获得_axis轴上数据的中位数坐标
        median_idx = len(points) // 2
        
        # 6、利用median_idx来构造当前的“根节点”
        node = KD_Node(points[median_idx], depth)
        
        # 7、利用median_idx左边的数据构建左子树
        node.left = self._build(points[0:median_idx], depth+1)
        
        # 8、利用median_idx右边的数据构建右子树
        node.right = self._build(points[median_idx+1:], depth+1)
        
        return node
        
    def build(self):
        self.root = self._build(self.data, 0)
        return self.root
        
        
def preorder(root):
    print(root.data)
    if root.left:
        preorder(root.left)
    if root.right:
        preorder(root.right)
        
def inorder(root):
    if root.left:
        inorder(root.left)
    print(root.data)
    if root.right:
        inorder(root.right)

利用统计学习方法P54例3.2中给定的二维空间的数据集对KD树构建进行检查:

x = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]

完成KD树构建并对其进行前序和中序遍历得下:
在这里插入图片描述
前序遍历和中序遍历的结果均正确,KD树构建成功。

4.2 搜索KD树

4.2.1 搜索基本流程

首先给出KD树最近邻搜索算法的基本流程:
输入:构造好的KD树,以及目标点x
输出:目标点的最近邻

  1. 首先在KD树中寻找包含目标点的叶结点:从根节点出发,按照上述构建kd树depth和特征向量维度的关系改变样本数据比较的维度,在当前维度axis上,若目标点 x [ a x i s ] < n o d e [ a x i s ] x[axis]<node[axis] x[axis]<node[axis],则当前节点移动到左子结点,否则移动到右子节点,直到节点为叶节点为止。
  2. 以当前叶结点为“当前最近点”。
  3. 递归地向上回退,在每个结点上进行如下操作:
    a. 如果当前结点保存的实例点比当前最近点距离目标点更近,则以该实例点为“当前最近点”。
    b.当前最近点存在于该结点一个子结点对应的区域。检查该子结点的父结点的另一个子结点对应的区域是否有更近的点。具体地,检查另一个子结点对应区域是否以目标点为球心、以目标点与“当前最近点“的距离为半径的超球体相交:如果相交,可能在另一个子结点对应的区域存在距离目标点更近的点,移动到另一个子结点,接着递归地进行最近邻搜索;不相交则向上回退。
  4. 回退到根节点时,搜索结束,最后的”当前最近点“即为 x x x的最近邻点。

4.2.2 算法步骤分解

视频讲解链接:KNN KD_Trees
图源:https://www.youtube.com/watch?v=oQQrxiJvnhw
根据上图,可以看到图片构建好的KD树以及目标点 ( 6 , 7 ) (6,7) (6,7)。在算法启动时,我们将最近距离(best distance)设置为 f l o a t [ " i n f " ] float["inf"] float["inf"]、将当前最近点设置为 N o n e None None
按照算法流程走,第一步先将目标节点与根节点进行比较:计算出目标结点与根结点的距离,并与保存的最近距离相比较。将最近距离更新为目标结点与根节点的距离,将当前最近点设置为根结点。
第二部即递归向下遍历各结点:此时,算法通过比较结点与目标点当前轴上的数据大小来决定走左子树或者右子树。可以很容易看出, a x i s = 0 axis=0 axis=0 6 < 7 6<7 6<7,此时结点向左子树移动。
图源:https://www.youtube.com/watch?v=oQQrxiJvnhw
接着是不断递归向下遍历直到当前结点为叶结点,在这个过程中,将计算每个节点与目标点的距离,如果这个距离小于上述所保存的最近距离,则将最近距离更新为这个距离,同时更新当前最近点
图源:https://www.youtube.com/watch?v=oQQrxiJvnhw
当我们遍历到叶结点时,需要特别注意的是:我们需要对所谓的”bad side of the tree"上的结点与目标点进行比较吗?答案是不需要,这也是算法中3(b)步:检查该子结点的父结点的另一个子结点对应的区域是否有更近的点
图源:https://www.youtube.com/watch?v=oQQrxiJvnhw
但在此处我们并不需要计算目标点与 ( 2 , 3 ) (2,3) (2,3)的距离,以上图为例,我们只需要计算目标点到“bad side of the tree“所对应的轴的距离,并将这个距离与当前保存的最近距离进行比较,因为这也是目标点距离这片区域最近的点(红色区域):
若该距离大于最近距离,那么就没有再去比较的必要。该步对应于算法中的:检查另一个子结点对应区域是否以目标点为球心、以目标点与“当前最近点“的距离为半径的超球体相交
若该距离小于最近距离,那么说明另一个结点所在的区域可能存在结点与目标点的距离小于当前最小距离,那么就需要转移到另一个结点,递归地进行最近邻搜索。
图源:https://www.youtube.com/watch?v=oQQrxiJvnhw

最后,回退到根节点时,搜索结束,最后的”当前最近点“即为 x x x的最近邻点。

4.2.3 最近邻搜索的Python实现

result = collections.namedtuple("result", "nearest_point nearest_dist")

def find_nearest(kd_tree, point):
    k = len(point) 
    
    def search(kd_node, target, max_dist):
        # 递归退出条件
        if kd_node is None:
            return result([0] * k, float("inf"))
        
        # 获取当前结点对应的切分轴以及当前结点存储的数据
        split = kd_node.depth % k
        node_data = kd_node.data
        
        # 在当前切分轴上,目标数据小于当前结点数据,则判定其离左子树更近,反之为右子树
        if target[split] <= node_data[split]:
            nearer_node = kd_node.left
            farther_node = kd_node.right
        else:
            nearer_node = kd_node.right
            farther_node = kd_node.left
            
        # 递归向下遍历直到叶结点
        templeaf = search(nearer_node, target, max_dist) # 直达叶节点
        
        # 以叶结点作为当前最近点,以叶结点与目标的距离作为当前最近距离
        nearest = templeaf.nearest_point
        dist = templeaf.nearest_dist
        
        # 如果当前最近距离小于最大距离,则更新最大距离
        if dist < max_dist:
            max_dist = dist
            
        # 判断“bad side of the tree”所对应的区域是否有可能存在离目标点更近的点
        temp_dist = abs(node_data[split] - target[split])
        
        # 如果max_dist小于temp_dist,则目标点到兄弟结点所在区域的最小距离仍大于当前最小距离,这种情况就不用对当前结点的兄弟结点进行考察
        if max_dist < temp_dist:
            return result(nearest, dist)
        
        # 否则就要对当前结点的兄弟结点进行考察,首先计算node_data与目标点之间的距离
        temp_dist = np.sqrt(sum((p1 - p2) ** 2 for p1,p2 in zip(node_data, target)))
        
        # 如果temp_dist小于当前最近距离,则更新当前最小距离和当前最近结点
        if temp_dist < dist:
            nearest = node_data
            dist = temp_dist
            max_dist = dist
        
        # 在兄弟结点上递归搜索
        temp2 = search(farther_node, target, max_dist)
        
        if temp2.nearest_dist < dist:
            nearest = temp2.nearest_point
            dist = temp2.nearest_dist
            
        return result(nearest, dist)
    
    
    return search(kd_tree, point, float("inf"))

4.3 KD_Tree上的KNN

4.3.1 算法描述

输入: 构造好的kd树;目标点x
输出: x的k个近邻

  1. 在kd树中找出包含目标点的叶结点:从根节点出发递归向下访问,目标点当前维度的数据小于当前结点当前维度的数据时移动到左子结点,反之则移动到右子结点,直到结点为叶结点。
  2. 若当前k近邻点集元素数量小于k或者叶结点距离小于”当前k近邻点集“中的最远距离,那么将叶结点插入到”当前k近邻点集“;
  3. 递归地向上回退:
    a. 如果当前k近邻点集元素数量小于k或者当前结点小于当前k近邻点集中最远点距离,那么将该结点插入当前k近邻点集
    b. 检查另一个子结点所对应的区域是否以目标点为球心、以目标点与当前k近邻点集中最远点间的距离为半径的超球体相交,如果相交,可能在另一个子结点对应的区域存在距离目标点更近的点,移动到另一个子结点,接着,递归地进行最近邻搜索;如果不相交,向上回退;
  4. 当回退到根结点时,搜索结束,最后的*“当前k近邻点集”*即为x的最近邻点。

4.3.2 Python实现

Code:DataWhale

import json


class Node:
    """节点类"""

    def __init__(self, value, index, left_child, right_child):
        self.value = value.tolist()
        self.index = index
        self.left_child = left_child
        self.right_child = right_child

    def __repr__(self):
        return json.dumps(self, indent=3, default=lambda obj: obj.__dict__, ensure_ascii=False, allow_nan=False)

class KDTree:
    """kd tree类"""

    def __init__(self, data):
        # 数据集
        self.data = np.asarray(data)
        # kd树
        self.kd_tree = None
        # 创建平衡kd树
        self._create_kd_tree(data)

    def _split_sub_tree(self, data, depth=0):
        # 算法3.2第3步:直到子区域没有实例存在时停止
        if len(data) == 0:
            return None
        # 算法3.2第2步:选择切分坐标轴, 从0开始(书中是从1开始)
        l = depth % data.shape[1]
        # 对数据进行排序
        data = data[data[:, l].argsort()]
        # 算法3.2第1步:将所有实例坐标的中位数作为切分点
        median_index = data.shape[0] // 2
        # 获取结点在数据集中的位置
        node_index = [i for i, v in enumerate(
            self.data) if list(v) == list(data[median_index])]
        return Node(
            # 本结点
            value=data[median_index],
            # 本结点在数据集中的位置
            index=node_index[0],
            # 左子结点
            left_child=self._split_sub_tree(data[:median_index], depth + 1),
            # 右子结点
            right_child=self._split_sub_tree(
                data[median_index + 1:], depth + 1)
        )

    def _create_kd_tree(self, X):
        self.kd_tree = self._split_sub_tree(X)

    def query(self, data, k=1):
        data = np.asarray(data)
        hits = self._search(data, self.kd_tree, k=k, k_neighbor_sets=list())
        dd = np.array([hit[0] for hit in hits])
        ii = np.array([hit[1] for hit in hits])
        return dd, ii

    def __repr__(self):
        return str(self.kd_tree)

    @staticmethod
    def _cal_node_distance(node1, node2):
        """计算两个结点之间的距离"""
        return np.sqrt(np.sum(np.square(node1 - node2)))

    def _search(self, point, tree=None, k=1, k_neighbor_sets=None, depth=0):
        if k_neighbor_sets is None:
            k_neighbor_sets = []
        if tree is None:
            return k_neighbor_sets

        # (1)找到包含目标点x的叶结点
        if tree.left_child is None and tree.right_child is None:
            # 更新当前k近邻点集
            return self._update_k_neighbor_sets(k_neighbor_sets, k, tree, point)

        # 递归地向下访问kd树
        if point[0][depth % k] < tree.value[depth % k]:
            direct = 'left'
            next_branch = tree.left_child
        else:
            direct = 'right'
            next_branch = tree.right_child
        if next_branch is not None:
            # (3)(a) 判断当前结点,并更新当前k近邻点集
            k_neighbor_sets = self._update_k_neighbor_sets(
                k_neighbor_sets, k, next_branch, point)
            # (3)(b)检查另一子结点对应的区域是否相交
            if direct == 'left':
                node_distance = self._cal_node_distance(
                    point, tree.right_child.value)
                if k_neighbor_sets[0][0] > node_distance:
                    # 如果相交,递归地进行近邻搜索
                    return self._search(point, tree=tree.right_child, k=k, depth=depth + 1,
                                        k_neighbor_sets=k_neighbor_sets)
            else:
                node_distance = self._cal_node_distance(
                    point, tree.left_child.value)
                if k_neighbor_sets[0][0] > node_distance:
                    return self._search(point, tree=tree.left_child, k=k, depth=depth + 1,
                                        k_neighbor_sets=k_neighbor_sets)

        return self._search(point, tree=next_branch, k=k, depth=depth + 1, k_neighbor_sets=k_neighbor_sets)

    def _update_k_neighbor_sets(self, best, k, tree, point):
        # 计算目标点与当前结点的距离
        node_distance = self._cal_node_distance(point, tree.value)
        if len(best) == 0:
            best.append((node_distance, tree.index, tree.value))
        elif len(best) < k:
            # 如果“当前k近邻点集”元素数量小于k
            self._insert_k_neighbor_sets(best, tree, node_distance)
        else:
            # 叶节点距离小于“当前 𝑘 近邻点集”中最远点距离
            if best[0][0] > node_distance:
                best = best[1:]
                self._insert_k_neighbor_sets(best, tree, node_distance)
        return best

    @staticmethod
    def _insert_k_neighbor_sets(best, tree, node_distance):
        """将距离最远的结点排在前面"""
        n = len(best)
        for i, item in enumerate(best):
            if item[0] < node_distance:
                # 将距离最远的结点插入到前面
                best.insert(i, (node_distance, tree.index, tree.value))
                break
        if len(best) == n:
            best.append((node_distance, tree.index, tree.value))

# 打印信息
def print_k_neighbor_sets(k, ii, dd):
    if k == 1:
        text = "x点的最近邻点是"
    else:
        text = "x点的%d个近邻点是" % k

    for i, index in enumerate(ii):
        res = X_train[index]
        if i == 0:
            text += str(tuple(res))
        else:
            text += ", " + str(tuple(res))

    if k == 1:
        text += ",距离是"
    else:
        text += ",距离分别是"
    for i, dist in enumerate(dd):
        if i == 0:
            text += "%.4f" % dist
        else:
            text += ", %.4f" % dist

    print(text)

import numpy as np

X_train = np.array([[2, 3],
                    [5, 4],
                    [9, 6],
                    [4, 7],
                    [8, 1],
                    [7, 2]])
kd_tree = KDTree(X_train)
# 设置k值
k = 1
# 查找邻近的结点
dists, indices = kd_tree.query(np.array([[3, 4.5]]), k=k)
# 打印邻近结点
print_k_neighbor_sets(k, indices, dists)

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值