k-近邻算法详解-附代码

 

项目概述

海伦使用约会网站寻找约会对象。经过一段时间之后,她发现曾交往过三种类型的人:

  • 不喜欢的人
  • 魅力一般的人
  • 极具魅力的人

她希望:

  1. 工作日与魅力一般的人约会
  2. 周末与极具魅力的人约会
  3. 不喜欢的人则直接排除掉

现在她收集到了一些约会网站未曾记录的数据信息,这更有助于匹配对象的归类

40920    8.326976    0.953952    3    432
14488    7.153469    1.673904    2     124
26052    1.441871    0.805124    1     45
75136    13.147394    0.428964    1     34
38344    1.669788    0.134296    1      12 

23432    1.453445    0.123453    1      0

现在根据上面我们得到的样本集中所有人与未知人群的距离,按照距离递增排序,可以找到 k 个距离最近的人。 假定 k=3,发现倒数三个,是距离最近的,随即得出,未知人的label是1

KNN 原理

KNN 工作原理

  1. 假设有一个带有标签的样本数据集(训练样本集),其中包含每条数据与所属分类的对应关系。
  2. 输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较。
    1. 计算新数据与样本数据集中每条数据的距离。
    2. 对求得的所有距离进行排序(从小到大,越小表示越相似)。
    3. 取前 k (k 一般小于等于 20 )个样本数据对应的分类标签。
  3. 求 k 个数据中出现次数最多的分类标签作为新数据的分类。

KNN 通俗理解

给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最邻近的 k 个实例,这 k 个实例的多数属于某个类,就把该输入实例分为这个类。

算法思路有了,以下是实现。

将文本记录转换为 NumPy 的解析程序

def getDataArr(path):
    file = open(path, "r")
    dataLines = file.readlines()
    dataLen = len(dataLines)
    dataLabels = []
    idx = 0
    dataList = np.zeros((dataLen, 3))
    for line in dataLines:
        line = line.strip()
        data = line.split()
        dataList[idx, :] = data[:3]
        dataLabels.append(int(data[-1]))
        idx = idx + 1
    return np.array(dataList), np.array(dataLabels)

分析数据: 使用 Matplotlib 画二维散点图

def getPlot(dataList, dataLabels):
    plt.scatter(dataList[:, 0], dataList[:, 1], c = dataLabels)

归一化数据 (归一化是一个让权重变为统一的过程)

方法有如下:

  1. 线性函数转换,表达式如下:   

    y=(x-MinValue)/(MaxValue-MinValue)  

    说明: x、y分别为转换前、后的值,MaxValue、MinValue分别为样本的最大值和最小值。  

  2. 对数函数转换,表达式如下:   

    y=log10(x)  

    说明: 以10为底的对数函数转换。

    如图:

  3. 反余切函数转换,表达式如下:

    y=arctan(x)*2/PI 

    如图:

     

  4. 式(1)将输入值换算为[-1,1]区间的值,在输出层用式(2)换算回初始值,其中和分别表示训练样本集中负荷的最大值和最小值。 

在统计学中,归一化的具体作用是归纳统一样本的统计分布性。归一化在0-1之间是统计的概率分布,归一化在-1--+1之间是统计的坐标分布。

def autoNorm(dataList):
    minVal = dataList.min(0)
    maxVal = dataList.max(0)
    #归一化
    rangeVal = maxVal - minVal
    normDataList = np.zeros(np.shape(dataList))
    m = dataList.shape[0]
    
    normDataList = dataList - np.tile(minVal, (m, 1))
    normDataList = normDataList / np.tile(rangeVal, (m, 1))
    
    return normDataList, rangeVal, minVal
kNN 算法伪代码:
对于每一个在数据集中的数据点: 
    计算目标的数据点(需要分类的数据点)与该数据点的距离
    将距离排序: 从小到大
    选取前K个最短距离
    选取这K个中最多的分类类别
    返回该类别来作为目标数据点的预测值

def classifyFun(inX, dataList, dataLabels, k):
    dataListLen = dataList.shape[0]
    difMat = np.tile(inX, (dataListLen, 1)) - dataList
    sqDifMat = difMat ** 2
    sqDifMatSum = sqDifMat.sum(axis = 1)
    DifDist = sqDifMatSum ** 0.5
    # 排序
    DifDistSortIdx = DifDist.argsort()
    classCount = {}
    for i in range(k):
        getLabel = dataLabels[DifDistSortIdx[i]]
        classCount[getLabel] = classCount.get(getLabel, 0) + 1
    classifyAns = sorted(classCount.items(), key = operator.itemgetter(1),reverse = True)
    return classifyAns[0][0]

主函数的inX输入你需要计算的类别

    filePath = "datingTestSet2.txt"
    inX = [15669, 0.000000, 1.250185]
    k = 5
    dataList, dataLabels = getDataArr(filePath)
    getPlot(dataList, dataLabels)
    autoNorm(dataList)
    preLable = classifyFun(inX, dataList, dataLabels, k)
    print("labels = ", preLable)
    plt.scatter(inX[0], inX[1], c = preLable, marker = '+', s = 1000)

计算结束。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值