KNN(K近邻)算法 简单实现

1.简介

K近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。Cover和Hart在1968年提出了最初的邻近算法。KNN是一种分类(classification)算法,它输入基于实例的学习(instance-based learning),属于懒惰学习(lazy learning)即KNN没有显式的学习过程,也就是说没有训练阶段,数据集事先已有了分类和特征值,待收到新样本后直接进行处理。与急切学习(eager learning)相对应。

KNN是通过测量不同特征值之间的距离进行分类。

思路是:如果一个样本在特征空间中的k个最邻近的样本中的大多数属于某一个类别,则该样本也划分为这个类别。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。

提到KNN,网上最常见的就是下面这个图,可以帮助大家理解。
  在这里插入图片描述

我们要确定绿点属于哪个颜色(红色或者蓝色),要做的就是选出距离目标点距离最近的k个点,看这k个点的大多数颜色是什么颜色。当k取3的时候,我们可以看出距离最近的三个,分别是红色、红色、蓝色,因此得到目标点为红色。

算法的描述:
1)计算测试数据与各个训练数据之间的距离;

2)按照距离的递增关系进行排序;

3)选取距离最小的K个点;

4)确定前K个点所在类别的出现频率;

5)返回前K个点中出现频率最高的类别作为测试数据的预测分类

参考:
https://blog.csdn.net/zaishuiyifangxym/article/details/95311111
https://www.cnblogs.com/jyroy/p/9427977.html

2.代码实现

import numpy as np
import matplotlib.pyplot as plt
import imageio

data = [] ##存储已知数据
undata = []  ##待处理数据点
cls = []	##数据类别 0/1
##计算欧式距离
def distEuclid(x,y):
    return np.sqrt(np.sum((x-y)**2))

##随机生成若干待分类数据 rg位数据值范围
def genUnknownData(n,rg):
    while len(undata)<n:
        p = np.around(np.random.rand(2)*2*rg,decimals=2)
        undata.append(p)
 
##随机生成两类已知数据       
def genOriData(n,rg):
    while len(data)<n:
        p = np.around(np.random.rand(2)*rg,decimals=2)
        data.append(p)
    cls.extend([0 for i in range(n)])
    while len(data)<2*n:
        p = np.around(np.random.rand(2)*rg+rg,decimals=2)
        data.append(p)
    cls.extend([1 for i in range(n)])
    
## knn算法实现
def Knn(k):
    for x in undata:
        dis = {}
        ##计算未知点x与已知点的距离
        for i in range(len(data)):
            dis[i] = distEuclid(x,data[i])
            
        ##找到最近k个点 根据k个点的分类情况决定点x的类别
        res=sorted(dis.items(),key=lambda x:x[1])[:k]
        clsn=0
        for (i,j) in res:
            clsn += cls[i]
        if clsn/k>=0.5:
            cls.append(1)
        else:
            cls.append(0)
            
        data.append(x)
        Show()

##图片展示      
def Show():
    color = ['r','g','b','c','y','m','k']
    for i in range(len(data)-1):
        mark = int(cls[i])
        plt.plot(data[i][0],data[i][1],color[mark]+'o')
    mark = int(cls[len(data)-1])
    plt.plot(data[len(data)-1][0],data[len(data)-1][1],color[mark]+'x')

    plt.show()
    plt.pause(1)
  
genUnknownData(10,5)
genOriData(5,5)
k=3
Show()
Knn(k)

结果如下:
在这里插入图片描述

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值