KNN算法以及kd树搜索k个节点的实现

K N N KNN KNN(K-nearest neighbor)

K K K近邻算法

  • 输入:训练数据集

T = { ( x 1 , y 1 ) , ( x 2 , y 2 ) , ⋅ ⋅ ⋅ , ( x N , y N ) } T =\{(x_1,y_1),(x_2,y_2),···,(x_N,y_N)\} T={(x1,y1),(x2,y2),,(xNyN)}

​ 其中, x i ∈ R n x_i\in R^n xiRn 为实例的特征向量, y i ∈ Y = { c 1 , c 2 , c 3 , ⋅ ⋅ ⋅ , c K } y_i\in Y = \{c_1,c_2,c_3,···,c_K\} yiY={c1,c2,c3,,cK} 为实例的类别, i = 1 , 2 , ⋅ ⋅ ⋅ , N i = 1,2,···,N i=1,2,,N

​ 输出:实例 x x x 所属的类 y y y

  • 算法过程:
    1. 根据给定的距离度量,在训练集 T T T 中找出与 x x x 最邻近的 k k k 个点,涵盖这 k k k 个点的 x x x 领域记作 N k ( x ) N_k(x) Nk(x);
    2. N k ( x ) N_k(x) Nk(x) 中根据分类决策规则(例如多数表决)决定 x x x 的类别 y y y

K K K近邻模型

  • 距离度量:

    特征空间中两个实例点的距离是两个实例点相似程度的反映。 k k k 近邻模型的特征空间一般是 n n n维实数向量空间 R n R^n Rn。使用的距离一般是欧式距离,也可以是其他距离,比如** L p L_p Lp距离或者 M i n k o w s k i Minkowski Minkowski距离**。

    设特征空间 X X X n n n维实数向量空间 R n R^n Rn x i , x j ∈ X x_i,x_j \in X xi,xjX x i = ( x i ( 1 ) , x i ( 2 ) , ⋅ ⋅ ⋅ , x i ( n ) ) x_i = (x_i^{(1)},x_i^{(2)},···,x_i^{(n)}) xi=(xi(1),xi(2),,xi(n)) x j = ( x j ( 1 ) , x j ( 2 ) , ⋅ ⋅ ⋅ , x j ( n ) ) x_j = (x_j^{(1)},x_j^{(2)},···,x_j^{(n)}) xj=(xj(1),xj(2),,xj(n)),则 x i , x j x_i,x_j xi,xj L p L_p Lp距离定义为:
    L p ( x i , x j ) = ( ∑ l = 1 n ∣ x i ( l ) − x j ( l ) ∣ p ) 1 p ( p ≥ 1 ) L_p(x_i,x_j) = (\sum_{l=1}^n|x_i^{(l)-x_j^(l)}|^p)^{\frac{1}{p}} (p \geq 1) Lp(xi,xj)=(l=1nxi(l)xj(l)p)p1(p1)

    • p = 2 p = 2 p=2 时,称为欧式距离 ( E u c l i d e a n (Euclidean (Euclidean d i s t a n c e ) distance) distance)
    • p = 1 p = 1 p=1 时,称为曼哈顿距离 ( M a n h a t t a n (Manhattan (Manhattan d i s t a n c e ) distance) distance)
    • p = ∞ p = \infty p= 时,它是个个坐距离的最大值

    选择不同的距离度量对预测结果的影响不同

  • k k k值选择:

    k k k的选择会对 k k k近邻结果产生重大影响

    k k k值较小学习的“近似误差”会减小,但是“估计误差”会增大,如果 k k k实例点所在的领域恰好包含噪声,预测则会出错; k k k值较大可以减少“估计误差”,但是“近似误差会增大”,也就是说离实例点较远的训练点也会对其产生一定影响。

  • 分类决策规则:

    一般采用多数表决的方法

K K K近邻的实现

暴力解法(线性查找)

以下代码参考于Li-hang

import numpy as np
import pandas as pd
from collections import Counter
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

class KNN():
    def __init__(self,X_train,Y_train,n_neighbors = 3,p = 2,):
        """
        :params
        	n_neighbors:前面所讨论的k
        	p:Lp距离参数
        """
        self.X_train = X_train
        self.Y_train = Y_train
        self.p = p
        self.n_neighbors = n_neighbors

    def predict(self,X):
        """
        :param
            X: 输入空间中的变量
        :return
            Y: 预测的结果
        """
        # knn_list存放筛选出来的n个结果
        knn_list = []

        # 先选出n个凑满knn_list
        for i in range(self.n_neighbors):
            # 计算X和训练集中前n个的距离,范数为p
            distance = np.linalg.norm(abs(X-self.X_train[i]),ord = self.p)
			knn_list.append((distance,self.Y_train[i]))


        # 遍历整个训练集寻找最小的n个值
        for i in range(self.n_neighbors,len(self.X_train)):
            # 计算距离
            distance = np.linalg.norm(abs(X-self.X_train[i]),ord = self.p)
            # 找到最大值的下标,knn_list元素是元组,比较对象为元组的第一个元素
            max_index = knn_list.index(max(knn_list,key= lambda x:x[0]))

            # 如果distance小于knn_list中的最大值,则替换
            if knn_list[max_index][0] > distance:
                knn_list[max_index] = (distance,self.Y_train[i])

        # n个最小距离的标签值
        labels = [k[1] for k in knn_list]
        # 返回一个Counter对象,含有label以及其出现的次数(dict)
        label = Counter(labels)

        # 将label中的字典元素根据出现次数即x[1]从大到小排序,并且取Y为第一个字典中的第一个元素
        Y = sorted(label.items(),key=lambda x: x[1],reverse = True)[0][0]
        return Y

    def score(self,X_test,Y_test):
        """
        :param
            X_test:测试数据集输入
            Y_test:测试数据集target
        :return
            返回估计正确率
        """

        # 计算正确的个数
        counter = 0
        for index in range(len(X_test)):
            # 预测结果
            Y = self.predict(X_test[index])
            # 预测结果等于真实值则counter增加
            if Y == Y_test[index]:
                counter+=1
        # 返回正确率
        return counter/len(X_test)
kd树

kd树初识

kd树搜索

# 以下代码包含kd树的建立以及kd树搜索k个近邻点

# kd树节点
class KdNode():
    def __init__(self,split,left,right,value):
        self.value = value
        self.left = left
        self.right = right
        self.split = split
        self.visited = False

# kd树
class KdTree():

    def __init__(self,data):
        # k是维度
        k = len(data[0])


        # 创建一颗kd树
        def CreateTree(split,data_set:np.ndarray):

            # 如果为空返回None
            if  len(data_set) == 0:
                return None
            # 将data_set.argsort(axis = 0)按照列从小到大排序
            # data_set.argsort(axis = 0)[:,split]选取第split列(这一列是位置)
            # data_set[data_set.argsort(axis=0)[:,split]]按照列次序重排
            data_set = data_set[data_set.argsort(axis=0)[:,split]]

            # 找到中点分隔
            split_pos = len(data_set)//2
            # 中点的数据
            median = data_set[split_pos]
            # 下一次的分割点
            split_next = (split+1)%k

            # 左
            left = CreateTree(split_next,data_set[:split_pos])
            # 右
            right = CreateTree(split_next,data_set[split_pos+1:])
            # 当前节点的值
            node = KdNode(split,left,right,median)
            return node

        self.root = CreateTree(0,data)

    # 先序非递归遍历kd树
    def preorder(self):

        lst = []
        tree = self.root
        while tree or len(lst):

            while tree:
                print(tree.value)

                if tree.right:
                    lst.append(tree)
                tree = tree.left

            if len(lst):
                tree = lst.pop().right

    # 利用kd树找出X的k个最近点
    def search(self,X,p = 2,n_neighbors = 1):
        # knn_list用来存放(distance,value)
        knn_list = []

        # 递归寻找
        def Recursive(node:KdNode):

            # 节点为空直接返回
            if not node:
                return True

            # node没被访问过
            if node.visited == False:
                
                # 左枝
                if X[node.split]<=node.value[node.split]:
                    Recursive(node.left)
                # 右枝
                elif X[node.split] > node.value[node.split]:
                    Recursive(node.right)
            # 访问过则直接返回
            else:
                return True

            # 标记为访问过
            node.visited = True

            # 如果knn_list中不够n_neighbors个元素则向其中添加元素
            if len(knn_list) < n_neighbors:
                # 距离
                dist = np.linalg.norm(abs(X - node.value), ord=p)
                # 添加(distance,value)元组
                knn_list.append((dist,node.value))

                # 判断另一枝
                if X[node.split] <= node.value[node.split]:
                    Recursive(node.right)
                else:
                    Recursive(node.left)

            # 如果knn_list中已经有n_neighbors个元素
            else:
                # 离边界的距离
                edge_dist = abs(X[node.split]-node.value[node.split])

                # knn_list中的最大距离所在的元组
                max_dist = max(knn_list,key=lambda x:x[0])

                # 如果离边界的距离大于最大距离,则在另一边不可能有更小的值
                if edge_dist > max_dist[0]:
                    return True

                # 如果离边界的距离小于最大距离,则当前节点以及另一边可能存在更小值
                else:

                    # 当前节点和X的距离
                   dist = np.linalg.norm(abs(X - node.value), ord=p)
                   # 如果当前节点和X距离更小
                   if dist < max_dist[0]:
                       # 找到下标
                       index = knn_list.index(max_dist)
                       # 替换
                       knn_list[index] = (dist,node.value)

                   # 寻找另一枝
                   if X[node.split] <= node.value[node.split]:
                        Recursive(node.right)
                   else:
                        Recursive(node.left)

            return True

        # 深拷贝副本,便于多次使用
        root = copy.deepcopy(self.root)


        Recursive(root)
        return knn_list

读者可自行将kd树建立搜索稍加修改应用到KNN算法中

关注公众号我们共同进步

关注公众号我们共同进步
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值