机器学习实战:KNN算法讲解

机器学习实战:KNN算法讲解

    KNN算法本章内容来至于《统计学习与方法》李航,《机器学习》周志华,以及《机器学习实战》Peter HarringTon,相互学习,不足之处请大家多多指教

    1.1 KNN算法的优缺点

    1.2 KNN算法的工作机制

    1.3 KNN算法的python实现

    1.4 我对KNN算法的理解

1.1 KNN算法的优缺点

      优点:KNN算法是分类数据最简单的算法,具有精度高,对异常值不明显,无数据输入假定的特点。

      缺点:KNN算法必须保存全部的数据,如果训练的数据集比较大,必须使用大量的存储空间,而且对每个数据距离计算,可能会比较耗时,KNN算法的另一个缺陷是无法给出任何数据的基础结构信息,无法知道实例样本和典型样本具有什么特征。

1.2 KNN算法的工作机制

【1】KNN算法:给定测试样本,基于某种距离度量找到训练集中最靠近的K个训练样本,然后基于这K个邻居的信息来进行预测,通常在分类任务重可以使用“投票法”,即选择这K个样本中出现最多的类别标记作为预测结果,在回归任务中可以使用平均分,将k个样本的实值输出标记的平均值作为预测值,或者是积极与距离远近进行加权平均或者加权投票,距离越近的样本权重越大。-----周志华 《机器学习P225页》

输入:训练数据T = {(x1,y1),(x2,y2),(x3,y3),(x4,y4),……(xn,yn)},

实例的类别y={c1,c2,c3……,cn},以及实例向量x。

输出:实例x所属的类别y。

算法过程:

(1)根据给定的距离度量,在训练集T中找出与实例x最近的K个点,涵盖这k个点的x的领域记为Nk(x)

(2)在Nk(x)中,根据分类决策规则,如多数表决,决定x的类别

(3)K近邻算法的特殊情况是K=1的情况,称为最近邻算法,对于输入的实例点,最近邻算法将训练数据集中与x最近点的类作为x的类

 

【2】距离度量包括LP距离,欧氏距离,曼哈顿距离,《统计学习与方法》

其中xi,xj的LP距离定义为:

 

当P=2时候,称为欧氏距离:

 

当P= 1时候,称为曼哈顿距离:

 

当P=无穷大时候,他是坐标的最大值:

  

【3】关于K值的选择对KNN算法的影响

如果选择较小的K值,就相当于用较小的领域中的训练实例进行预测,学习的近似误差会减小,只有输入实力和相似点的训练实例较近时候,才会对预测起结果,但学习的误差估计会增大,预测结果对对近邻的实例点非常敏感,如果近邻是噪声点就会出错,换句话说K值变小,会使得整体模型变得复杂,容易发生过拟合

如果K值比较大,就相当于用较大的领域中的训练实例进行预测,其优点是会减少学习的估计误差,但是缺点是学习的近似误差会增大,这时候与输入实例较远的点也会对训练实例起预测作用,使得预测发生错误K值的增大会使得整个模型变得更加简单。

在训练过程中,K值通常比较小,通常采用交叉验证法来选取合适的K值

    

1.3 KNN算法的python实现

参照机器学习实战的例子,使用KNN算法改进约会网站的配对效果

#!/usr/bin/python
#-*- encoding:utf-8 -*-
from numpy import  *
import  numpy as np
import operator
import matplotlib as mpl
import matplotlib.pyplot as plt

#添加Linux黑体字库,避免matplotlib显示中文乱码
mpl.rcParams['font.sans-serif'] = [u'SimHei']
mpl.rcParams['axes.unicode_minus'] = False

def createDataSet():
    #训练的数据T
    group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
    #分类的标签label
    labels = ['A','A','B','B']
    return  group,labels

def classify0(inx,dataSet,labels,k):
    dataSetSize = dataSet.shape[0]
    print dataSet
    diffMat = tile(inx,(dataSetSize,1)) - dataSet
    print  diffMat
    sqDiffMat = diffMat **2
    sqDistances = sqDiffMat.sum(axis=1)
    print 'sqDistances =',sqDistances
    distances = sqDistances**0.5
    sortedDistIndicies = distances.argsort()
    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel,0)+1
    sortedClassCount = sorted(classCount.iteritems(),key = operator.itemgetter(1),reverse = True)
    return sortedClassCount[0][0]

#从文本数据中获得数据
def file2matrix(filename):
    fr = open(filename)
    arrayOlines = fr.readlines();
    numberOfLines = len(arrayOlines)
    returnMat = zeros((numberOfLines,3))
    classLabelVector = []
    index = 0
    for line in arrayOlines:
        line = line.strip()
        listFromLine = line.split('\t')
        returnMat[index,:]=listFromLine[0:3]
        classLabelVector.append(int(listFromLine[-1]))
        index +=1
    return returnMat,classLabelVector


#归一化空间
def autoNorm(dataSet):
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals -minVals
    normDataSet = zeros(shape(dataSet))
    m = dataSet.shape[0]
    normDataSet = dataSet-tile(minVals,(m,1))
    normDataSet = normDataSet/tile(ranges,(m,1))
    return normDataSet,ranges,minVals


def datingClassTest():
    hoRatio = 0.10
    datingDataMat,datingLabels = file2matrix('datingTestSet.txt')
    normMat,ranges,minVals = autoNorm(datingDataMat)
    m=normMat.shape[0]
    numTestVecs = int(m*hoRatio)
    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
        print 'the classifier came back with: %d ,the real answer is :%d'%(classifierResult,datingLabels[i])
        if(classifierResult!= datingLabels[i]):errorCount+=1.0
    print "the total error rate is :%f"%(errorCount/float(numTestVecs))

if __name__ == "__main__":
    # datingClassTest()
    group,labels = createDataSet();
    datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')
    fig =plt.figure(facecolor='w')
    ax = fig.add_subplot(111)
    ax.scatter(datingDataMat[:,1],datingDataMat[:,2],15.0*array(datingLabels),15.0*array(datingLabels))

    plt.xlabel(u"玩游戏所耗时间百分比",fontsize=14)
    plt.ylabel(u"每周消耗的冰淇淋公斤升数",fontsize=14)
    plt.title(u"约会网站KNN算法预测")
    plt.show()


实验的结果:


 

代码技巧

1:归一化特征

 

2:使用多通道颜色显示不同的类别

 

代码调试过程中出现的bug

 

代码下载:

 

 

  • 0
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值