kNN算法识别手写数字(代码笔记)

k-近邻算法,属于有监督分类算法。

思想:利用输入数据特征值和训练样本数据特征值之间的距离分类,挑出距离最小的k个训练样本的类别频率,作为预测的分类估计。

'''
k-近邻算法是基于实例的学习
1 使用时要保存全部的数据集,占存储空间
2 要对每个训练数据计算距离值,实际使用时非常耗时
'''
import numpy as np
import operator

def classify0(x, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    diffMat = np.tile(x, (dataSetSize,1)) - dataSet
    sqDiff = diffMat**2
    sqDist = sqDiff.sum(axis=1)
    distances = sqDist**0.5  # 一行数据的平方根
    sortedDistInd = distances.argsort()  # 向量元素从小到大对应的索引号
    classCount = {}
    for i in range(k):  # 前k个,也就是最近的k个; 统计类出现的频率
        vLabel = labels[sortedDistInd[i]]  
        classCount[vLabel] = classCount.get(vLabel,0)+1
    sortedClassCount = sorted(classCount.items(), # 转成dict_items:[(key1,cnt1),(key2,cnt2),..]
                       key=operator.itemgetter(1), # 排序,依据tuple第二个元素;reverse,由大到小
                       reverse=True)
    return sortedClassCount[0][0]
    
def img2vec(filename):  # 32x32的矩阵数据转成向量
    vec = np.zeros((1,1024))
    fr = open(filename)  # (如果是txt文件的话)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            vec[0, 32*i+j] = int(lineStr[j])
    return vec

def handwritingClassify():
    trainLab = []
    trainFileList = listdir('trainingDigits')  # 训练数据目录
    m = len(trainFileList)
    trainMat = zeros((m,1024))  # 训练数据存成一个矩阵
    for i in range(m):
        filenameStr = trainFileList[i]
        fileStr = filenameStr.split('.')[0]
        classStr = int(fileStr.split('_')[0])
        trainLab.append(classStr)
        trainMat[i,:] = img2vec('trainingDigits/%s' % filenameStr)
    #------------------------ 测试数据 -------------------------
    errorCount = 0.0
    testFileList = listdir('testDigits')  # 测试数据目录
    n = len(testFileList)
    for i in range(n):
        filenameStr = testFileList[i]
        fileStr = filenameStr.split('.')[0]
        classStr = int(fileStr.split('_')[0])
        vecTest = img2vec('testDigits/%s' % filenameStr)
        classTest = classify0(vecTest, trainMat, trainLab, 3)  # 测试数据的直接分类
        print("the classifier predicts : %d, the real is : %d" % (classTest,classStr))
        if(classTest!=classStr):
            error += 1.0
    print("\n the total numbers of errors is: %d" % errorCount)
    print("\n the total error rate is: %d" % (error/float(n)))


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值