用 KNN 做手写数字识别
目录
作为一个小白,写此文章主要是为了自己记录,方便回过头来查找! 本文主要参考ApacheCN(专注于优秀项目维护的开源组织)中MachineLearning中的KNN项目。有些代码参考这个项目(大部分),有些代码是自己写的。建议去ApacheCN去学习,还有专门的视频讲解,个人感觉非常好。下面对利用KNN进行手写数字识别的过程进行简要的描述:
1. KNN的原理
KNN的原理,本文不做解释,想做了解的人可以去ApacheCN上的项目进行学习或者观看对应视频学习。
2. KNN实现手写数字识别过程
本文主要是测试了一下在测试集上的准确度。测试集样本个数为946个(数据集同样可以在ApacheCN上面进行下载),训练集样本个数为1934个(0~9),其样本保存方式是用.txt文件保存的图片文本。用KNN实现手写识别的核心思想就是在训练集中找到一个欧氏距离最小的那个样本所属的类别,用该类别来确定未知样本的类别。
在识别中需要对图片进行向量化,因此需要一个图片转换成向量的函数:
# 将图像文本数据转换为向量 def img2vector(filename): returnVect = np.zeros((1,1024)) # returnVect = [] fr = open(filename) for i in range(32): read_oneline=fr.readline() for j in range(32): returnVect[0,i*32+j]=int(read_oneline[j]) return returnVect
然后就是在测试集上的精度测试:
def handwritingClassTest(filename,testFileName): # 1. 导入训练数据 hwLabels=[] # 标签集 trainingFileList = os.listdir(filename) # 获得文件列表 m = len(trainingFileList) trainingMat = np.zeros((m,1024)) for i in range(m): fileNameStr = trainingFileList[i] classNumStr = fileNameStr.split('_')[0] hwLabels.append(classNumStr) filename_all = filename+'/'+fileNameStr trainingMat[i, :] = img2vector(filename_all) # 2. 导入测试数据 testFileList=os.listdir(testFileName) mTest = len(testFileList) errorCount = 0.0 for i in range(mTest): fileNameStr = testFileList[i] classNumStr = int(fileNameStr.split('_')[0]) filename_all = testFileName + '/' + fileNameStr vectorUnderTest =img2vector(filename_all) classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels) print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)) if (classifierResult != classNumStr): errorCount += 1.0 print("\nthe total number of errors is: %d" % errorCount) print("\nthe total error rate is: %f" % (errorCount / float(mTest)))
上面测试功能中的classify0利用欧氏距离度量实现了最近类别查找:
def classify0(testVector, traningMat, hwLabels): row_num_train=traningMat.shape[0] testMat=np.zeros((row_num_train,1024)) for i in range(row_num_train): testMat[i,:]=testVector diff=testMat-traningMat diff=np.abs(diff) diff_row=np.sum(diff,axis=1) # 因为向量中的值不是1就是-1,平方后都是1,因此开根号后直接进行求和即可。 diff_min_index=np.argmin(diff_row) return int(hwLabels[diff_min_index])
最后,用 handwritingClassTest 函数测试一下就OK了。测试集946个,错误了13个,错误率为0.013742。自己可以试一下,代码有些地方写的不够规范,体谅下吧。