前言
前两篇文章《约会对象魅力程度分类》和《使用sklearn中的KNN算法》已经把KNN算法相关内容介绍完毕,从本节开始再举几个例子加深对KNN的理解。
本节主要记录MLiA中手写识别系统。
转载请注明出处:http://blog.csdn.net/rosetta
KNN手写识别系统
这节内容也是很简单的,不同之处在于原始数据表示方法不同而已。每个样本都是32行*32列=1024大小,所以使用一个1024大小的list存放一个样本。
训练数据trainingDigits目录,共1934个,测试数据testDigits目录,共946个。每个数字大概200个样本,样本的标签通过文件名标识,文件名下划线左边的就是该样本实际的数字,即标签了。
该程序先使用训练数据创建分类算法,然后使用testDigits数据进行测试。
具体细节可对照代码中的注释进行学习,注释已经很详细了。
#以下开始学习2.3节:手写识别系统
#该函数用于读取样本文字到list中。
#因为每个样本都是32行*32列=1024大小。所以这个返回什returnVect就是一个1024大小的list,存放了样本的数据。
def img2vector(filename):
returnVect = np.zeros((1,1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0,32*i+j] = int(lineStr[j])
return returnVect
def img2vector_test():
testVector = img2vector('testDigits/0_13.txt')
print(testVector[0, 0:31])
#2.3.2 使用k-近邻算法识别手写数字。
#把2.1和2.2节真正学会了,这节其实是很简单的。不同之处在于原始数据表示方法不同而已。
def handwritingClassTest():
hwLabels = []
trainingFileList = listdir('trainingDigits')#获取训练数据文件名列表,一共1934个
m = len(trainingFileList)#m=1934
trainingMat = np.zeros((m,1024))#创建1934行,每行1024列。
for i in range(m):
fileNameStr = trainingFileList[i]
fileStr = fileNameStr.split('.')[0] #take off .txt
classNumStr = int(fileStr.split('_')[0])#每个样本的标签通过文件名标识,文件名下划线左边的就是该样本实际的数字,所以程序先解析这些数据,把它们存在hwLabels中。
hwLabels.append(classNumStr)
trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)#获取训练样本的属性特征。
#以下开始获取测试样本的属性特征和标签,和上述获取训练数据时是一样的,只不过本实例是一条一条测试,
#其实可使用sklearn KNN相关算法先把测试样本的属性和标签存下来,然后调用相应接口即可。
testFileList = listdir('testDigits') #iterate through the test set
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0] #take off .txt
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
if (classifierResult != classNumStr): errorCount += 1.0
print("\nthe total number of errors is: %d" % errorCount)
print("\nthe total error rate is: %f" % (errorCount/float(mTest)))
main中调用的地方:
#以下开始学习2.3节:手写识别系统
#8、2.3.1将图像转换为测试向量。
# img2vector_test()
#9、2.3.2 使用k-近邻算法识别手写数字。
handwritingClassTest()
运行结果如下,错误率大概1%多点。
至此KNN手写识别系统介绍完毕。
如有疑问之处欢迎加我微信交流,共同进步!请备注“CSDN博客”