机器学习实战:K近邻算法(kNN)

算法伪代码:

1、计算输入向量intX距训练集中各点的距离

2、将距离按从小到大排序

3、选取最小的k个值

4、统计其中各类标签数目

5、选取标签出现最多的,作出预测


代码如下:

#intX为输入向量,k为选取最邻近的k个点    
def classify0(intX,dataset,labels,k):
    diff=dataset-intX                 #运用了广播机制,使dataset的每一行都减去输入向量
    diff2=diff**2
    sum1=sum(diff2,axis=1)            #按行相加  
    distance=sum1**0.5
    sortedDis=distance.argsort()      #argsort函数返回排序索引  
    dic={}
    #选取据输入向量最近的k个点,统计标签数目
    for i in range(k):
        voteLabel=labels[sortedDis[i]]
        dic[voteLabel]=dic.get(voteLabel,0)+1  #字典的get函数:Get an element with a default  
        
    sortedDic=sorted(dic.iteritems(),key=operator.itemgetter(1),reverse=True)
    return sortedDic[0][0]
sorted函数:排序后原序列顺序不变

sorted(iterable, cmp=None, key=None, reverse=False)

operator.itemgetter(1)通过比较第二个数据成员来排序


示例:使用k近邻算法改进约会网站的配对效果

from numpy import *
import operator
import matplotlib
import matplotlib.pyplot as plt

#intX为输入向量,k为选取最邻近的k个点    
def classify0(intX,dataset,labels,k):
    diff=dataset-intX                 #运用了广播机制,使dataset的每一行都减去输入向量
    diff2=diff**2                     #取平方  
    sum1=sum(diff2,axis=1)            #按行相加  
    distance=sum1**0.5                #开方  
    sortedDis=distance.argsort()      #argsort函数返回排序索引  
    dic={}                            #字典  
    #选取据输入向量最近的k个点,统计标签数目
    for i in range(k):
        voteLabel=labels[sortedDis[i]]
        dic[voteLabel]=dic.get(voteLabel,0)+1  #字典的get函数:Get an element with a default  
        
    sortedDic=sorted(dic.iteritems(),key=operator.itemgetter(1),reverse=True)
    return sortedDic[0][0]
    
def file2matrix(filename):
    fr=open(filename)
    fileLines=fr.readlines()          #将文件的每一行读取后,作为List中的一个元素。每个元素为一个字符串    
    numberOfLines=len(fileLines)      #文件总共行数(List的长度)  
    returnMat=zeros((numberOfLines,3))  #用0初始化一个m*n的数组(numpy数组)
    classLabelVector=[]
    index=0
    
    for line in fileLines:
        line=line.strip()             #Python strip() 方法用于移除字符串头尾指定的字符(默认为空格)。   
        splitLine=line.split('\t')    #split()通过指定分隔符对字符串进行切片,如果参数num 有指定值,则仅分隔 num 个子字符串
                                      #此为用'\t'将字符串分隔成数组(List)形式  形如:['38343', '7.241614', '1.661627', '3\n']
        returnMat[index,:]=splitLine[0:3]
        classLabelVector.append(int(splitLine[-1]))
        index+=1
    return returnMat,classLabelVector
           
def autoNorm(dataset):
    maxOfCol=dataset.max(0)           #计算列最大值    
    minOfCol=dataset.min(0)
    max_minOfCol=maxOfCol-minOfCol
    normDataSet=(dataset-minOfCol)/max_minOfCol
    return normDataSet,minOfCol,max_minOfCol

def draw():
    datingDataMat,datingLabels=file2matrix('datingTestSet2.txt')
    #调用figure创建一个绘图对象,并且使它成为当前的绘图对象
    fig = plt.figure()             
    #绘制的图像在1*1的网格中占第一部分(即占满)
    ax = fig.add_subplot(111)      
    #datingDataMat矩阵中第二列为x,第三列为y
    ax.scatter(datingDataMat[:,0], datingDataMat[:,1],15.0*array(datingLabels), 15.0*array(datingLabels))
    plt.xlabel("frequent flier miles")
    plt.ylabel("percentage of time spent playing video games")
    plt.show()

def datingClassTest():
    testRatio=0.1
    datingDataMat,datingLabels=file2matrix('datingTestSet2.txt') 
    normDataSet,Min,ranges=autoNorm(datingDataMat)
    tNum=int(testRatio*(normDataSet.shape[0]))
    trainingSet=normDataSet[tNum:,:]
    trainingLabels=datingLabels[tNum:]
    error=0.0
    k=4
    
    for i in range(tNum):
        testResult=classify0(normDataSet[i,:],trainingSet,trainingLabels,k)
        print "the test result is %d, the real is %d" %(testResult,datingLabels[i])
        if testResult != datingLabels[i]:
            error+=1.0
    print "error number is %d, error rate is %f" %(error,error/float(tNum))

def classifyPerson():
    datingDataMat,datingLabels=file2matrix('datingTestSet2.txt') 
    normDataSet,Min,ranges=autoNorm(datingDataMat)
    testRatio=0.1
    tNum=int(testRatio*(normDataSet.shape[0]))
    trainingSet=normDataSet[tNum:,:]
    trainingLabels=datingLabels[tNum:]
    
    arg1=input("frequent flier miles earned per year?")
    arg2=input("percentage of time spent playing video games?")
    arg3=input("liters of ice cream consumed per year?")
    intX=array([arg1,arg2,arg3])
    testResult=classify0((intX-Min)/ranges,trainingSet,trainingLabels,3)
    labelTable=['not at all','in small doses','in large doses']
    print "You will probably like this person: ",labelTable[testResult-1]

#draw()
#datingClassTest()
classifyPerson()

手写数字识别系统的测试代码

from numpy import *
from os import listdir

def img2vector(filename):
    fr=open(filename)
    lines=fr.readlines()
    mat=zeros((1,1024))               #numpy数组
    k=0
    for line in lines:
        for j in range(32):
            mat[0][k]=int(line[j])    #numpy数组应将一维显示地表示成m*n形式
            k+=1
    return mat
    
#intX为输入向量,k为选取最邻近的k个点    
def classify0(intX,dataset,labels,k):
    diff=dataset-intX                 #运用了广播机制,使dataset的每一行都减去输入向量
    diff2=diff**2                     #取平方  
    sum1=sum(diff2,axis=1)            #按行相加  
    distance=sum1**0.5                #开方  
    sortedDis=distance.argsort()      #argsort函数返回排序索引  
    dic={}                            #字典  
    #选取据输入向量最近的k个点,统计标签数目
    for i in range(k):
        voteLabel=labels[sortedDis[i]]
        dic[voteLabel]=dic.get(voteLabel,0)+1  #字典的get函数:Get an element with a default  
        
    sortedDic=sorted(dic.iteritems(),key=operator.itemgetter(1),reverse=True)
    return sortedDic[0][0]

def handwritingClassTest():
    trainFileList=listdir('trainingDigits')    #将目录下所有文件的文件名用List表示,每一个文件名为List中一个元素
    testFileList=listdir('testDigits')
    m=len(trainFileList)
    n=len(testFileList)
    error=0.0
    trainMat=zeros((m,1024))
    
    realResult=[]
    for i in range(m):
        mat=img2vector('trainingDigits/'+trainFileList[i])
        trainMat[i]=mat
        realResult.append(int(trainFileList[i][0]))
        
        
    for i in range(n):
        testMat=img2vector('testDigits/'+testFileList[i])
        testResult=classify0(testMat,trainMat,realResult,3)
        testRealResult=int(testFileList[i][0])
        print 'test result is %d, real result is %d' %(testResult,testRealResult)
        if testResult != testRealResult:
            error+=1
            
    print 'there is %d error, the error rate is %f' %(error,error/float(n))    
        
handwritingClassTest()    

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值