KNN算法进阶


前言

之前一篇文章里主要对KNN的代码进行了测试,而这篇文章是为了弥补理论的不足。

一、算法分析

算法图解

先举一个简单但是很有特点的例子,如下图所示:
KNN实例

我们的样本空间中有11个样本(6蓝和5红),对于不确定的分类(绿色),我们找到离其最近的k个点,通过出现次数更多的颜色来确定绿色待测样本的分类。这就有一个很有意思的现象,图中,分别选择了2个k值(k=3和k=5)k=3时,分类中红色有2个,蓝色只有1个,因此会将绿色划分到成红色类。而k=5时,蓝色则多于红色,会将绿色分到蓝色类。
从这个小例子中,可以得出k近邻算法的三个基本要素:度量方式、k值的选择、分类决策规则。

1.度量方式

K近邻模型的特征空间可被抽象成n维的空间向量v,现在两个对象之间的距离就可转换为两个向量间的距离,在这里我采用欧氏距离:
在这里插入图片描述

2.k值的选择

  • 当k值为一个时,此算法被成为最近邻算法。
  • 如果选择较小的k值,相当于用大部分的样本用于训练,学习近似误差减小,容易发生过拟合。
  • 如果选择较大的k值,相当于用小部分的样本用于训练,学习近似误差增大,容易发生欠拟合。
  • 在实际应用中,k值一般取一个比较小的数值,可以用交叉验证法来选择最优的k值。

    3.分类决策规则

    k近邻算法中大多采用多数表决,即由输入对象的k个邻近中的多数类决定输入对象的类。

    二、测试算法

    1.约会配对

    def file2matrix(filename):
        fr = open(filename)
        arrayOfLines = fr.readlines()
        numberOfLines = len(arrayOfLines)
        returnMat = zeros((numberOfLines, 3))
        classLabelVetor = []
        index = 0
        for line in arrayOfLines:
            line = line.strip()
            listFromLine = line.split('\t')
            returnMat[index, :] = listFromLine[0:3]
            classLabelVetor.append(int(listFromLine[-1]))
            index += 1
        fr.close()
        return returnMat, classLabelVetor
    
    
    def autoNorm(dataSet):
        minVals = dataSet.min(0)
        maxVals = dataSet.max(0)
        ranges = maxVals - minVals
        normDataSet = zeros(shape(dataSet))
        m = dataSet.shape[0]
        normDataSet = dataSet - tile(minVals, (m, 1))
        normDataSet = normDataSet / tile(ranges, (m, 1))
        return normDataSet, ranges, minVals
    
    
    def datingClassTest():
        hoRatio = 0.05
        datingDateMat, datingLabels = file2matrix('D:\python\PyCode\machinelearninginaction\Ch02\datingTestSet2.txt')
        normMat, ranges, minVals = autoNorm(datingDateMat)
        m = normMat.shape[0]
        numTestVecs = int(m * hoRatio)
        errorCount = 0.0
        for i in range(numTestVecs):
            classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
            print("the classifier came back with:%d, the real answer is:%d" % (classifierResult, datingLabels[i]))
            if (classifierResult != datingLabels[i]):
                errorCount += 1.0
        print("the total error rate is: %f" % (errorCount / float(numTestVecs)))
        
    if __name__ == '__main__':
        group, labels = createDataSet()
        # print(group,end="\n")
        # print(labels)
        # inp = [5,5.6]
        # print("样本:" + str(inp))
        # print(classify0(inp, group, labels, 3))
    
        # datingDateMat, datingLabels = file2matrix('D:\python\PyCode\machinelearninginaction\Ch02\datingTestSet2.txt')
        # fig = plt.figure()
        # ax = fig.add_subplot(111)
        # ax.scatter(datingDateMat[:, 1], datingDateMat[:, 2], 15.0 * array(datingLabels), 15.0 * array(datingLabels))
        # plt.show()
        datingClassTest()
    

    在这里插入图片描述

    2.手写体识别

    def img2vector(filename):
        returnVect = zeros((1, 1024))
        fr = open(filename)
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
        return returnVect
    
    
    def handwritingClassTest():
        hwLabels = []
        trainingFileList = listdir('D:\python\PyCode\machinelearninginaction\Ch02\\testDigits')
        m = len(trainingFileList)
        trainingMat = zeros((m, 1024))
        for i in range(m):
            fileNameStr = trainingFileList[i]
            fileStr = fileNameStr.split('.')[0]
            classNumStr = int(fileStr.split('_')[0])
            hwLabels.append(classNumStr)
            trainingMat[i, :] = img2vector('D:\python\PyCode\machinelearninginaction\Ch02\\testDigits/%s' % fileNameStr)
        testFileList = listdir('D:\python\PyCode\machinelearninginaction\Ch02\\testDigits')
        errorCount = 0.0
        mTest = len(testFileList)
        for i in range(mTest):
            fileNameStr = testFileList[i]
            fileStr = fileNameStr.split('.')[0]
            classNumStr = int(fileStr.split('_')[0])
            vetorUnderTest = img2vector('D:\python\PyCode\machinelearninginaction\Ch02\\testDigits/%s' % fileNameStr)
            classifierResult = classify0(vetorUnderTest, trainingMat, hwLabels, 3)
            print("the classifier came backwith: %d,the real answer is:%d" % (classifierResult, classNumStr))
            if classifierResult != classNumStr:
                errorCount += 1.0
        print("\n the total number of errors is:%d" % errorCount)
        print("\n the total error rate is: %f" % (errorCount / float(mTest)))
    
    
    if __name__ == '__main__':
        group, labels = createDataSet()
        # print(group,end="\n")
        # print(labels)
        # inp = [5,5.6]
        # print("样本:" + str(inp))
        # print(classify0(inp, group, labels, 3))
    
        # datingDateMat, datingLabels = file2matrix('D:\python\PyCode\machinelearninginaction\Ch02\datingTestSet2.txt')
        # fig = plt.figure()
        # ax = fig.add_subplot(111)
        # ax.scatter(datingDateMat[:, 1], datingDateMat[:, 2], 15.0 * array(datingLabels), 15.0 * array(datingLabels))
        # plt.show()
        # datingClassTest()
    
        handwritingClassTest()
    

    总结

    该算法在分类时有个主要的不足是:当样本不平衡时,当一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本占多数。因此可以采用权值的方法(和该样本距离小的邻居权值大)来改进。该方法的另一个不足之处是计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的K个最近邻点,因此应该可以先去除对判断样本分类用处不大的样本。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值