机器学习基础算法-k近邻算法

算法概述

算法定义:k近邻采用计算预测点与样本数据特征的距离,选取k个距离最近的样本标签(这个标签其实就是数据的分类,这个数据属于哪个类别,比如某一条电影数据前面有好几个特征值,后边这个标签就是标记这条电影数据属于那种类型的电影,例如属于爱情类电影,这个标签就是爱情),找到k个距离最近的样本标签后,统计这k个样本重出现次数最多的那个分类。
算法条件:需要有样本数据,由于需要计算距离,样本数据特征之间的数值大小相差不应该太大,比如某个特征数值为10000,另外一个特征数值为0.11,这样计算出来的距离和第一个特征正相关,其他特征被屏蔽,这样的样本数据就需要对样本数据进行归一化处理,可以将数据转化为0-1之间的数值

可以使用这个公式: (x-min)/(max-min)

  • 式中的x表示当前被归一化的数据
  • 式子中的min是这该特征数据的最小值
  • max是这个特征数据最大值

通过上面的式子,我们就能够把这个特征数据进行归一化处理
算法实现过程

  1. 准备样本数据,样本数据最好能够有一定的格式,这样在进行程序处理的时候就能快速的处理了
  2. 样本数据归一化(如果需要的话)
  3. 输入待分类数据
  4. 计算带分类数据与各条样本数据的距离
  5. 找到前k个距离最近的样本数据
  6. 找到这个k个样本数据中,出现次数最多的标签

找到计算次数最多的标签,即为这个待分类数据的预测分类

算法实例

下面介绍一个书中的实例,利用样本数据进行手写数字识别
这个例子中的样本数据是二值化的数据这个就是数字0二值化结果
通过0-9的10个手写样本数据进行加载,计算与待分类样本数据距离,通过距离最近的k个样本,估计待测样本数据的预测值。
在计算中大量使用了numpy这个python的科学计算包,进行矩阵的运算

from testpackage import test
from numpy import *
import operator
import matplotlib
import matplotlib.pyplot as plt
from os  import listdir

def classify0(inx,dataset,lables,k):
    datasetSize =  dataset.shape[0]
    diffdataset = tile(inx,(datasetSize,1))-dataset
    sqdiffdataset = diffdataset**2
    sqdiffdistances = sqdiffdataset.sum(axis=1)
    distances = sqdiffdistances**0.5
    sorteddistance = distances.argsort()
    classcount = {}
    for i in range(k):
        votelable = lables[sorteddistance[i]]
        classcount[votelable] = classcount.get(votelable,0) +1   
    sortedClassCount =  sorted(classcount.items(),key=operator.itemgetter(0),reverse = True)
    return sortedClassCount[0][0]
    
def datingfileToMatrix(filename):    
    fr = open(filename)
    datinglines = fr.readlines()
    datinglength = len(datinglines)
    datingMat = zeros((datinglength,3))
    classLablesV = []
    index = 0
    for line in datinglines:
        row = line.strip()
        listfromLine = row.split('\t')
        datingMat[index,:] = listfromLine[0:3]
        classLablesV.append(int(listfromLine[-1]))
        index+=1
    return datingMat,classLablesV

def autoNorm(dataSet):
    minvaules = dataSet.min(0)
    maxvalues = dataSet.max(0)
    rangevalue = maxvalues - minvaules
    normDataSet = zeros(shape(dataSet))
    m = dataSet.shape[0]
    normDataSet = dataSet - tile(minvaules,(m,1))
    normDataSet = normDataSet/tile(rangevalue,(m,1))
    return normDataSet,rangevalue,minvaules

def datingClassTest():
    hoRatio = 0.50      #hold out 10%
    datingDataMat,datingLabels = datingfileToMatrix('datingTestSet2.txt')       #load data setfrom file
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m*hoRatio) 
    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
        if (classifierResult != datingLabels[i]): errorCount += 1.0
    print("the total error rate is: %f" % (errorCount/float(numTestVecs)))
    print(errorCount)
    
def img2vector(filename):
    file = open(filename)
    returnvector = zeros((1,1024))
    for i in range(32):
        lineStr = file.readline()
        for j in range(32):
            returnvector[0,i*32+j] = int(lineStr[j])
    return returnvector

def handwritingclassTest():
    hwlables = []
    trainingfiles = listdir('digits/trainingDigits')
    m = len(trainingfiles)
    trainmat = zeros((m,1024))
    for i in range(m):
        filename = trainingfiles[i]
        fileStr = filename.split('.')[0]
        rightnumber = int(fileStr.split('_')[0])
        hwlables.append(rightnumber)
        trainmat[i,:] = img2vector('digits/trainingDigits/{0}'.format(filename))
    testfilelist = listdir('digits/testDigits')
    errorcount = 0.0
    mTest = len(testfilelist)
    for i in range(mTest):
        filename = testfilelist[i]
        fileStr = filename.split('.')[0]
        rightNumber = int(fileStr.split('_')[0])
        testVector = img2vector('digits/testDigits/{0}'.format(filename))
        classr = classify0(testVector,trainmat,hwlables,3)
        print('the classifier came back is {0},the real answer is : {1}'.format(classr,rightNumber))
        if classr != rightNumber:
            errorcount+=1
    print('the total error number is {0}'.format(errorcount))
    print('the total error rate is {0}'.format(errorcount/float(mTest)))
    
if __name__ == "__main__":
    group,lables = createDataset()
    r = classify0([0,0],group,lables,3)
    print(r)
    datingmat,classlab = datingfileToMatrix("datingTestSet2.txt")
    normMat,matranges,matminvals = autoNorm(datingmat)
#    print(normMat)
#    print(matranges)
#    print(matminvals)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.scatter(datingmat[:,0],datingmat[:,1],15.0*array(classlab),15.0*array(classlab))
    plt.show()
#    datingClassTest()
#    handwritingclassTest()
总结

k近邻算法的主要优点:

  • 对异常数据不敏感

缺点就是计算复杂度太高了,空间复杂度相对较高,他需要将带预测的数据与样本数据逐个计算距离,如果样本数据达到几百万条,这样的计算量是非常大的

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值