手写数字识别系统的测试代码
def handwritingClassTest():
hwLabels=[]
trainingFileList=listdir('trainingDigits')#将trainingDigits目录中的文件内容存储在训练列表"trainingFileList"中
m=len(trainingFileList)#得到目录中有多少文件,并将其存储在变量m中
trainingMat=zeros((m,1024))#创建一个m行1024列的训练矩阵,该矩阵的每行数据存储一个图像
for i in range(m):
fileNameStr=trainingFileList[i]#依次获取每个文件的名字
fileStr=fileNameStr.split('.')[0]#将文件名分割成两部分并只取第一部分。第二部分是后缀格式,我们不需要它
classNumStr=int(fileStr.split('_')[0])#获取类名(1-9)
hwLabels.append(classNumStr)#将类名依次添加到hwLabels里面
trainingMat[i,:]=img2vector('trainingDigits/%s'%fileNameStr)#打开目录中的每一个文件并添加到trainingMat这个矩阵中去
#print trainingMat
#下面对testDIGITS执行类似的操作
testFileList=listdir('testDigits')#将testDigits目录中的文件内容存储在测试列表"testFileList"中
errorCount=0.0
mTest=len(testFileList)#得到目录中有多少文件,并将其存储在变量mTest中
for i in range(mTest):
fileNameStr=testFileList[i]#依次获取每个文件的名字
fileStr=fileNameStr.split('.')[0]
classNumStr=int(fileStr.split('_')[0])
vectorUnderTest=img2vector('testDigits/%s'%fileNameStr)#这里和上面不同之处在于没有将读取的文件载入到矩阵中,而是要利用分类器进行测试,继续向下看
classifierResult=classify0(vectorUnderTest,trainingMat,hwLabels,3)#调用分类器进行测试
print "the classifier came back with: %d,the real answer is : %d"%(classifierResult,classNumStr)
if (classifierResult!=classNumStr):errorCount+=1.0
print "\nthe total number of errors is:%d"%errorCount
print "\nthe total error rate is :%f"%(errorCount/float(mTest))