knn算法的实现

本文参照:李航《统计学习方法》

k近邻法

k近邻法是一种基本的分类和回归的方法。

K近邻思想

k近邻的思想简单的来说就是一个投票选举,给定一个训练数据集(所有已经投票的人),输入一个新的实例(一个还未投出的票),去寻找与这个实例中最邻近的K个实例(和这个没有投出票最近的K个人),这K个实例的多数来自于某个类(投票的某个对象),这个新实例就是属于这个类(从众,和周围多数人保持一致)。

在这个过程中,有一个最重要的问题就是对K值的选择,K值过小,就是容易产生过拟合的现象(在数据的收集过程中一些噪音点,会对预测产生很大的影响),K值如果过大,会导致近似误差增大(范围大,就将那些不相似的数据吸收进来,使得预测发生错误)。

在下图1中,K如果选择蓝环内,x点就属于红点类,k的范围扩大到黄圈,x点就属于蓝三角类,再次扩大,又属于红点类,由此不难看出k的选择非常重要,选择不同,预测结果不同。

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False
# 简单脚本绘制同心圆
r = [2.0,6.0,10.0]
a, b = (0., 0.)
cc = np.arange(0, 2*np.pi, 0.01)
plt.figure() 
x1 = [1,0.2,9]
y1 = [5,0.2,2]
x2 = [-3,-5,4]
y2 = [-2,3,-1]
plt.scatter(x1,y1, color='r', s=25, marker="o")
plt.scatter(x2,y2, color ='b',s =35, marker="^")
plt.scatter(0,-0.5, color ='g',s =45, marker="x")
for i in r:
    x = a + i * np.cos(cc)
    y = b + i * np.sin(cc)
    plt.plot(x, y)
    plt.axis('equal')
plt.title("图1")
plt.show()

png

K近邻法的实现:kd树

k近邻在实现的时候,就是在不停的对实例点进行搜索,计算距离,比较远近,这样当数据量特别大的时候,如果采用线性检索,就得花费大量的时间,这是不可取的,为了提高搜索效率,就需要考虑使用特殊的存储结构来存储训练数据,减少计算距离的次数,其中有一种就为kd树。

kd树在k近邻算法中分为两个过程

1.构建平衡Kd树

2.使用kd树进行最近邻搜索

构建kd树

1.构建根节点,使根节点对应k维空间中包含所有实例点的超巨型区域

2.递归调用,不断的进行对k维空间进行切分,生成子节点,因为是构建的平衡kd树,所以我们将实例中x的中位数作为切分点。

说明:以下的程序实现中我们用到的是二维空间,也就是说特征值只有两个为 x,y,xi(x,y) x , y , x i ( x , y ) ,这样操作的时候先去企业 x x 的中维数做为切分点,然后下一个切分点用y的中位数….以此类推,直到所有的实例点被切分完成。多维同样如此。

数据集: T={(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)} T = { ( 2 , 3 ) , ( 5 , 4 ) , ( 9 , 6 ) , ( 4 , 7 ) , ( 8 , 1 ) , ( 7 , 2 ) }

import numpy as np
class Node: #结点结构
    def __init__(self, data, lchild = None, rchild = None):
        self.data = data
        self.lchild = lchild
        self.rchild = rchild
def sort(dataSet, axis): 
     '''
    输入:原始数据集,划分轴
    输出:排序好的结果
    描述:采用冒泡排序,利用axis作为轴进行划分
    '''

    sortDataSet = dataSet[:]   
    m, n = np.shape(sortDataSet)
    for i in range(m):
        for j in range(0, m - i - 1):
            if (sortDataSet[j][axis] > sortDataSet[j+1][axis]):
                temp = sortDataSet[j]
                sortDataSet[j] = sortDataSet[j+1]
                sortDataSet[j+1] = temp
    print("\n排序结果->%s" % sortDataSet)
    return sortDataSet

def create(dataSet, depth): 
     '''
    输入:原始数据集,树的深度
    输出:创建好的树,返回根节点
    描述:将排序好的数据,通过中位数,将数据分为两部分,得到左子树和和右子树,递归执行
    '''

    if (len(dataSet) > 0):
        m, n = np.shape(dataSet)    #求出样本行,列
        midIndex = int(m / 2) #中间数的索引位置
        axis = depth % n    #判断以哪个轴划分数据
        sortedDataSet = sort(dataSet, axis) #进行排序
        node = Node(sortedDataSet[midIndex]) #将节点数据域设置为中位数,具体参考下书本

        leftDataSet = sortedDataSet[: midIndex] #将中位数的左边创建2改副本
        rightDataSet = sortedDataSet[midIndex+1 :]
        print("\n左子树->%s" % leftDataSet)
        print("\n右子树->%s" % rightDataSet)
        node.lchild = create(leftDataSet, depth+1) #将中位数左边样本传入来递归创建树
        node.rchild = create(rightDataSet, depth+1)
        return node
    else:
        return None
dataSet = [[2, 3],
           [5, 4],
           [9, 6],
           [4, 7],
           [8, 1],
           [7, 2]]   
 tree = create(dataSet, 0)
排序结果->[[2, 3], [4, 7], [5, 4], [7, 2], [8, 1], [9, 6]]

左子树->[[2, 3], [4, 7], [5, 4]]

右子树->[[8, 1], [9, 6]]

排序结果->[[2, 3], [5, 4], [4, 7]]

左子树->[[2, 3]]

右子树->[[4, 7]]

排序结果->[[2, 3]]

左子树->[]

右子树->[]

排序结果->[[4, 7]]

左子树->[]

右子树->[]

排序结果->[[8, 1], [9, 6]]

左子树->[[8, 1]]

右子树->[]

排序结果->[[8, 1]]

左子树->[]

右子树->[]

def preOrder(node): 
    '''
    输入:树
    输出:遍历结果
    描述:如果节点不为空,就按照前序的方式遍历
    '''
    if node != None:
        print("遍历结果->%s" % node.data)
        preOrder(node.lchild)
        preOrder(node.rchild)

preOrder(tree)
遍历结果->[7, 2]
遍历结果->[5, 4]
遍历结果->[2, 3]
遍历结果->[4, 7]
遍历结果->[9, 6]
遍历结果->[8, 1]
搜索kd树

我们都直到二叉树的搜索,当数据量大的时候,搜索的效率很高,这是因为他可以通过剪枝的方式,减少对很多数据据点的检索,从而减少了计算的数量,kd树作为一种二叉树同样也有这种性质,这才是我们为什么要构建二叉树的原因。上面我们已近构建好了二叉树,下面呢,我们就是实现对二叉树的检索,给一个新的实例点进行预测。

搜索二叉树de过程

1.给出要预测的目标点;从根节点出发,递归的向下访问kd树,如果该目标点的坐标点小于切分点,就移动到左孩纸,否则移动到右孩子,一直到子节点为叶节点为止。

2.将这个叶节点保存为“当前最近节点”。

3.递归的向上回退,在没给节点做如下操作:

a) 如果该节点保存的实例点比当前最近点距离目标点更近,则将该实例点保存为“当前最近点”。

b)当前最近点一定从在于该节点一个子节点对应的区域,是否有更近的点,具体得,检查另一个子节点对应的区域是否与以目标点为球心,以目标点与“当前最近点”的距离为半径的超球体相交(本测试点因为是二维的,这个超球体就是一个圆),如果相交,可能在另一个子节点对应的区域内存在距离目标点更为接近的点,就移动到另外一个节点上,接着在另外的节点上进行最近邻搜索,反之,如果没有相交,就继续回退。

4.当回退到根节点时,搜索结束。最后保存的“当前最近点”就是目标点的最邻点。

目标点

基于上面的训练数据,给出测试点: x=[5,3] x = [ 5 , 3 ]

nearestPoint = None    #保存最近的点
nearestValue = 0   #保存最近的值 
count = 0 # 检索次数

def search(tree, x):  
    '''
    输入:树,要判断的实例点
    输出:得到最近节点
    '''
    global nearestPoint ,nearestValue,count
    travel(tree)
    return nearestPoint

def dist(a1, a2):
     '''
    输入:两个坐标点
    输出:算距离结果
    描述:通过欧式距离计算
    '''
    return ((np.array(a1) - np.array(a2)) ** 2).sum() ** 0.5
def travel(node, depth = 0):   

    '''
    输入:树
    输出:得到最近节点
    描述:二叉树搜索
    '''
    global nearestPoint ,nearestValue,count
    count +=1
    if node != None:    #递归终止条件
        n = len(x)  #特征数
        axis = depth % n    #计算轴
        if x[axis] < node.data[axis]:   #如果数据小于结点,则往左结点找
            travel(node.lchild, depth+1)
        else:
            travel(node.rchild, depth+1)

        #以下是递归完毕后,往父结点方向回朔,对应算法3.3(3)
        distNode_X = dist(x, node.data)  #目标和节点的距离判断
        if (nearestPoint == None): #确定当前点,更新最近的点和最近的值,对应算法3.3(3)(a)
            nearestPoint = node.data
            nearestValue = distNode_X
        elif (nearestValue > distNode_X):
            nearestPoint = node.data
            nearestValue = distNode_X

        print("————————第%s次检索————————"%count)
        print("搜索节点->%s"%node.data)
        print("搜素深度->%s"%depth)
        print("最近节点->%s"%nearestPoint)
        print("最近值->%s"%nearestValue)
        print("节点的值->%s"%node.data[axis])
        print("目标点的值->%s"% x[axis])
        #确定是否需要去子节点的区域去找(圆的判断),对应算法3.3(3)(b)
        if (abs(x[axis] - node.data[axis]) <= nearestValue):  
            if x[axis] < node.data[axis]:
                travel(node.rchild, depth+1)
            else:
                travel(node.lchild, depth + 1)
    return nearestPoint
x=[5,3]
re = search(tree, x)
print("\n最近邻点%s"%re)
————————第4次检索————————
搜索节点->[2, 3]
搜素深度->2
最近节点->[2, 3]
最近值->3.0
节点的值->2
目标点的值->5
————————第5次检索————————
搜索节点->[5, 4]
搜素深度->1
最近节点->[5, 4]
最近值->1.0
节点的值->4
目标点的值->3
————————第7次检索————————
搜索节点->[4, 7]
搜素深度->2
最近节点->[5, 4]
最近值->1.0
节点的值->4
目标点的值->5
————————第8次检索————————
搜索节点->[7, 2]
搜素深度->0
最近节点->[5, 4]
最近值->1.0
节点的值->7
目标点的值->5

最近邻点[5, 4]

End!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值