kNN算法学习笔记

假期开始学习Peter Harrington的《机器学习实战》一书,为督促自己学习,同时也分享一下经历,打算开始写下这一系列文章记录。

kNN算法,也是k-邻近算法,是一个分类算法,该算法的思想也比较简单,我们事先知道一些数据,也知道每条数据对应的分类关系,输入没有分类的新数据后,将新数据与已知数据中的每条数据进行比较,选择k个与之最相似(距离最近)的样本,新数据分类与这k个样本中多数类别相同,通常k值位于20以内。

kNN算法的一般流程如下:

1、收集数据;

2、准备数据:将数据以自己需要的格式进行准备;

3、分析数据:通过画图等一系列手段分析数据,查看大概规律等;

4、训练算法:kNN算法属于惰性学习,没有显示学习过程;

5、测试算法:用测试集测试,计算错误概率,更改参数,减小错误率;

6、使用算法。

根据一个手写识别系统的例子来说明以上步骤:

根据提供的图像文本,根据kNN算法分类,识别出是哪一个数字,因为数据集只有0~9,因此也只能识别0~9,需要识别的图像已经处理为黑白图像,如下所示:


图像为32*32的黑白图像,需要转化为1*1024的向量,编写下面函数进转化:

def img2vector(filename):
    returnVect = zeros((1, 1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0, 32*i+j] = int(lineStr[j])

    return returnVect

编写kNN算法实现代码,需要numpy库进行一系列矩阵运算,需要operator库进行迭代器参数输入,如下:

def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    sqDiffMat = diffMat ** 2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances ** 0.5
    sortedDistIndicies = distances.argsort()
    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

编写算法测试代码,用测试数据集对算法进行测试,需要os库进行目录操作,如下:

def handwriteClassTest():
    hwLabels = []
    k = 5
    trainFileList = listdir('trainingDigits')
    trainFileListSize = len(trainFileList)
    trainMat = zeros((trainFileListSize, 1024))
    for i in range(trainFileListSize):
        fileNameStr = trainFileList[i]
        classNumStr = int(fileNameStr.split('_')[0])
        hwLabels.append(classNumStr)
        trainMat[i, :] = img2vector('trainingDigits/' + fileNameStr)
    testFileList = listdir('testDigits')
    testFileListSize = len(testFileList)
    error = 0.0
    for i in range(testFileListSize):
        fileNameStr = testFileList[i]
        classNumStr = int(fileNameStr.split('_')[0])
        testVector = img2vector('testDigits/' + fileNameStr)
        classifierResult = classify0(testVector, trainMat, hwLabels, k)
        if classifierResult != classNumStr:
            error += 1
        print('classify result:%d, real result:%d' %(classifierResult, classNumStr))
    errorRate = error/float(testFileListSize)
    print('error number is: %d' %error)
    print('error rate is: %f' %errorRate)

测试结果如下:



在测试代码中,k=5,可以更改k值,来降低测试的错误率,这需要多次调试才能得到比较合适的k值。

完整代码可以在以下github中查看:

https://github.com/resistzzz/kNN

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值