机器学习经典算法详解及Python实现--K近邻(KNN)算法

转载http://blog.csdn.net/suipingsp/article/details/41964713

(一)KNN依然是一种监督学习算法

KNN(K Nearest Neighbors,K近邻 )算法是机器学习所有算法中理论最简单,最好理解的。KNN是一种基于实例的学习,通过计算新数据与训练数据特征值之间的距离,然后选取K(K>=1)个距离最近的邻居进行分类判断(投票法)或者回归。如果K=1,那么新数据被简单分配给其近邻的类。KNN算法算是监督学习还是无监督学习呢?首先来看一下监督学习和无监督学习的定义。对于监督学习,数据都有明确的label(分类针对离散分布,回归针对连续分布),根据机器学习产生的模型可以将新数据分到一个明确的类或得到一个预测值。对于非监督学习,数据没有label,机器学习出的模型是从数据中提取出来的pattern(提取决定性特征或者聚类等)。例如聚类是机器根据学习得到的模型来判断新数据“更像”哪些原数据集合。KNN算法用于分类时,每个训练数据都有明确的label,也可以明确的判断出新数据的label,KNN用于回归时也会根据邻居的值预测出一个明确的值,因此KNN属于监督学习。
KNN算法的过程为:
  1. 选择一种距离计算方式, 通过数据所有的特征计算新数据与已知类别数据集中的数据点的距离
  1. 按照距离递增次序进行排序,选取与当前距离最小的k个点
  1. 对于离散分类,返回k个点出现频率最多的类别作预测分类;对于回归则返回k个点的加权值作为预测值

(二)KNN算法关键

KNN算法的理论和过程就是那么简单,为了使其获得更好的学习效果,有下面几个需要注意的地方。
1、数据的所有特征都要做可比较的量化。
若是数据特征中存在非数值的类型,必须采取手段将其量化为数值。举个例子,若样本特征中包含颜色(红黑蓝)一项,颜色之间是没有距离可言的,可通过将颜色转换为灰度值来实现距离计算。另外,样本有多个参数,每一个参数都有自己的定义域和取值范围,他们对distance计算的影响也就不一样,如取值较大的影响力会盖过取值较小的参数。为了公平,样本参数必须做一些scale处理,最简单的方式就是所有特征的数值都采取归一化处置。
2、需要一个distance函数以计算两个样本之间的距离。
距离的定义有很多,如欧氏距离、余弦距离、汉明距离、曼哈顿距离等等,关于相似性度量的方法可参考‘漫谈:机器学习中距离和相似性度量方法’。一般情况下,选欧氏距离作为距离度量,但是这是只适用于连续变量。在文本分类这种非连续变量情况下,汉明距离可以用来作为度量。通常情况下,如果运用一些特殊的算法来计算度量的话,K近邻分类精度可显著提高,如运用大边缘最近邻法或者近邻成分分析法。
3,确定K的值
K是一个自定义的常数,K的值也直接影响最后的估计,一种选择K值得方法是使用 cross-validate(交叉验证)误差统计选择法交叉验证的概念之前提过,就是数据样本的一部分作为训练样本,一部分作为测试样本,比如选择95%作为训练样本,剩下的用作测试样本。通过训练数据训练一个机器学习模型,然后利用测试数据测试其误差率。 cross-validate(交叉验证)误差统计选择法就是比较不同K值时的交叉验证平均误差率,选择误差率最小的那个K值。例如选择K=1,2,3,... ,   对每个K=i做100次交叉验证,计算出平均误差,然后比较、选出最小的那个

(三)KNN分类

训练样本是多维特征空间向量,其中每个训练样本带有一个类别标签(喜欢或者不喜欢、保留或者删除)。分类算法常采用“多数表决”决定,即k个邻居中出现次数最多的那个类作为预测类。“多数表决”分类的一个缺点是出现频率较多的样本将会主导测试点的预测结果,那是因为他们比较大可能出现在测试点的K邻域而测试点的属性又是通过K领域内的样本计算出来的。解决这个缺点的方法之一是在进行分类时将K个邻居到测试点的距离考虑进去。例如,若样本到测试点距离为d,则选1/d为该邻居的权重(也就是得到了该邻居所属类的权重),接下来统计统计k个邻居所有类标签的权重和,值最大的那个就是新数据点的预测类标签。
举例,K=5,计算出新数据点到最近的五个邻居的举例是(1,3,3,4,5),五个邻居的类标签是(yes,no,no,yes,no)
若是按照多数表决法,则新数据点类别为no(3个no,2个yes);若考虑距离权重类别则为yes(no:2/3+1/5,yes:1+1/4)。
下面的Python程序是采用KNN算法的实例(计算欧氏距离,多数表决法决断):一个是采用KNN算法改进约会网站配对效果,另一个是采用KNN算法进行手写识别。
约会网站配对效果改进的例子是根据男子的每年的飞行里程、视频游戏时间比和每周冰激凌耗量三个特征来判断其是否是海伦姑娘喜欢的类型(类别为很喜欢、一般和讨厌),决策采用多数表决法。由于三个特征的取值范围不同,这里采用的scale策略为归一化。
使用KNN分类器的手写识别系统 只能识别数字0到9。需要识别的数字使用图形处理软件,处理成具有相同的色 彩和大小 :宽髙是32像素X32像素的黑白图像。尽管采用文本格式存储图像不能有效地利用内存空间,为了方便理解,这里已经将将图像转换为文本格式。训练数据中每个数字大概有200个样本,程序中将图像样本格式化处理为向量,即一个把一个32x32的二进制图像矩阵转换为一个1x1024的向量。
[python]  view plain  copy
 print ?
  1. from numpy import *  
  2. import operator  
  3. from os import listdir  
  4. import matplotlib  
  5. import matplotlib.pyplot as plt  
  6. import pdb  
  7.   
  8. def classify0(inX, dataSet, labels, k=3):  
  9.     #pdb.set_trace()  
  10.     dataSetSize = dataSet.shape[0]  
  11.     diffMat = tile(inX, (dataSetSize,1)) - dataSet  
  12.     sqDiffMat = diffMat**2  
  13.     sqDistances = sqDiffMat.sum(axis=1)  
  14.     distances = sqDistances**0.5  
  15.     sortedDistIndicies = distances.argsort() #ascend sorted,  
  16.     #return the index of unsorted, that is to choose the least 3 item      
  17.     classCount={}            
  18.     for i in range(k):  
  19.         voteIlabel = labels[sortedDistIndicies[i]]  
  20.         classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1# a dict with label as key and occurrence number as value  
  21.     sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)  
  22.     '''''descend sorted according to value, '''  
  23.     return sortedClassCount[0][0]  
  24.   
  25.   
  26. def file2matrix(filename):  
  27.     fr = open(filename)  
  28.     #pdb.set_trace()  
  29.     L = fr.readlines()  
  30.     numberOfLines = len(L)         #get the number of lines in the file  
  31.     returnMat = zeros((numberOfLines,3))        #prepare matrix to return  
  32.     classLabelVector = []                       #prepare labels return         
  33.     index = 0  
  34.     for line in L:  
  35.         line = line.strip()  
  36.         listFromLine = line.split('\t')  
  37.         returnMat[index,:] = listFromLine[0:3]  
  38.         classLabelVector.append(int(listFromLine[-1]))  
  39.         #classLabelVector.append((listFromLine[-1]))  
  40.         index += 1  
  41.     fr.close()  
  42.     return returnMat,classLabelVector  
  43.   
  44. def plotscattter():  
  45.     datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')       #load data setfrom file  
  46.     fig = plt.figure()  
  47.     ax1 = fig.add_subplot(111)  
  48.     ax2 = fig.add_subplot(111)  
  49.     ax3 = fig.add_subplot(111)  
  50.     ax1.scatter(datingDataMat[:,0],datingDataMat[:,1],15.0*array(datingLabels),15.0*array(datingLabels))  
  51.     #ax2.scatter(datingDataMat[:,0],datingDataMat[:,2],15.0*array(datingLabels),15.0*array(datingLabels))  
  52.     #ax2.scatter(datingDataMat[:,1],datingDataMat[:,2],15.0*array(datingLabels),15.0*array(datingLabels))  
  53.     plt.show()  
  54.       
  55.       
  56. def autoNorm(dataSet):  
  57.     minVals = dataSet.min(0)  
  58.     maxVals = dataSet.max(0)  
  59.     ranges = maxVals - minVals  
  60.     normDataSet = zeros(shape(dataSet))  
  61.     m = dataSet.shape[0]  
  62.     normDataSet = dataSet - tile(minVals, (m,1))  
  63.     normDataSet = normDataSet/tile(ranges, (m,1))   #element wise divide  
  64.     return normDataSet, ranges, minVals  
  65.      
  66. def datingClassTest(hoRatio = 0.20):  
  67.     #hold out 10%  
  68.     datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')       #load data setfrom file  
  69.     normMat, ranges, minVals = autoNorm(datingDataMat)  
  70.     m = normMat.shape[0]  
  71.     numTestVecs = int(m*hoRatio)  
  72.     errorCount = 0.0  
  73.     for i in range(numTestVecs):  
  74.         classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)  
  75.         print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i])  
  76.         if (classifierResult != datingLabels[i]): errorCount += 1.0  
  77.     print "the total error rate is: %.2f%%" % (100*errorCount/float(numTestVecs))  
  78.     print 'testcount is %s, errorCount is %s' %(numTestVecs,errorCount)  
  79.   
  80. def classifyPerson():  
  81.     ''''' 
  82.     input a person , decide like or not, then update the DB 
  83.     '''  
  84.     resultlist = ['not at all','little doses','large doses']  
  85.     percentTats = float(raw_input('input the person\' percentage of time playing video games:'))  
  86.     ffMiles = float(raw_input('flier miles in a year:'))  
  87.     iceCream = float(raw_input('amount of iceCream consumed per year:'))  
  88.     datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')  
  89.     normMat, ranges, minVals = autoNorm(datingDataMat)  
  90.     normPerson = (array([ffMiles,percentTats,iceCream])-minVals)/ranges  
  91.     result = classify0(normPerson, normMat, datingLabels, 3)  
  92.     print 'you will probably like this guy in:', resultlist[result -1]  
  93.   
  94.     #update the datingTestSet  
  95.     print 'update dating DB'  
  96.     tmp = '\t'.join([repr(ffMiles),repr(percentTats),repr(iceCream),repr(result)])+'\n'  
  97.   
  98.     with open('datingTestSet2.txt','a') as fr:  
  99.         fr.write(tmp)  
  100.   
  101. def img2file(filename):  
  102.     #vector = zeros(1,1024)  
  103.     with open(filename) as fr:  
  104.         L=fr.readlines()  
  105.     vector =[int(L[i][j]) for i in range(32for j in range(32)]  
  106.     return array(vector,dtype = float)  
  107.           
  108.   
  109. def handwritingClassTest():  
  110.     hwLabels = []  
  111.     trainingFileList = listdir('trainingDigits')           #load the training set  
  112.     m = len(trainingFileList)  
  113.     trainingMat = zeros((m,1024))  
  114.     for i in range(m):  
  115.         fileNameStr = trainingFileList[i]  
  116.         fileStr = fileNameStr.split('.')[0]     #take off .txt  
  117.         classNumStr = int(fileStr.split('_')[0])  
  118.         hwLabels.append(classNumStr)  
  119.         trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)  
  120.     testFileList = listdir('testDigits')        #iterate through the test set  
  121.     errorCount = 0.0  
  122.     mTest = len(testFileList)  
  123.     for i in range(mTest):  
  124.         fileNameStr = testFileList[i]  
  125.         fileStr = fileNameStr.split('.')[0]     #take off .txt  
  126.         classNumStr = int(fileStr.split('_')[0])  
  127.         vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)  
  128.         classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)  
  129.         print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)  
  130.         if (classifierResult != classNumStr): errorCount += 1.0  
  131.     print "\nthe total number of errors is: %d" % errorCount  
  132.     print "\nthe total error rate is: %f" % (errorCount/float(mTest))  
  133.   
  134. if __name__ == '__main__':  
  135.     datingClassTest()  
  136.     #handwritingClassTest()  

KNN算法学习包下载地址为:

Machine Learning K近邻算法

(四)KNN回归

数据点的类别标签是连续值时应用KNN算法就是回归,与KNN分类算法过程相同,区别在于对K个邻居的处理上。KNN回归是取K个邻居类标签值得加权作为新数据点的预测值。加权方法有:K个近邻的属性值的平均值(最差)、1/d为权重(有效的衡量邻居的权重,使较近邻居的权重比较远邻居的权重大)、高斯函数(或者其他适当的减函数)计算权重= gaussian(distance) (距离越远得到的值就越小,加权得到更为准确的估计。

(五)总结

K-近邻算法是分类数据最简单最有效的算法,其学习基于实例,使用算法时我们必须有接近实际数据的训练样本数据。K-近邻算法必须保存全部数据集,如果训练数据集的很大,必须使用大量的存储空间。此外,由于必须对数据集中的每个数据计算距离值,实际使用时可能非常耗时。k-近邻算法的另一个缺陷是它无法给出任何数据的基础结构信息,因此我们也无法知晓平均实例样本和典型实例样本具有什么特征。

本文作者Adan,来源于:机器学习经典算法详解及Python实现--K近邻(KNN)算法。转载请注明出处。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值