这里的每个数字都存储在一个文本文件中,文件内容是由 32*32 个 0 或 1 组成的数字矩阵:背景用 0 表示,数字用 1 表示。
from numpy import *
import operator
import os
def classify0(inX, dataSet, labels, k):
    '''
    Classify a feature vector by majority vote among its nearest training rows.

    The Euclidean distance between inX and every row of dataSet is computed,
    the rows are ranked from nearest to farthest, and the labels of the first
    k rows are counted.

    :param inX: feature vector to classify; it is tiled once per training row
    :param dataSet: 2-D array holding one training sample per row
    :param labels: sequence of labels indexed in parallel with the rows of dataSet
    :param k: number of nearest neighbours that vote; a value larger than the
        number of training rows is clamped to that number
    :return: the label with the most votes; on a tie, the label that was met
        first while walking the neighbours from nearest to farthest
    '''
    dataSetSize = dataSet.shape[0]
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    distances = ((diffMat ** 2).sum(axis=1)) ** 0.5
    sortedIndex = distances.argsort()

    # Never ask for more neighbours than there are training rows; the original
    # indexed past the end of sortedIndex in that case. A conditional is used
    # instead of min() because the file's star import of numpy may shadow it.
    neighbourCount = k if k < dataSetSize else dataSetSize

    classCount = {}
    for i in range(neighbourCount):
        label = labels[sortedIndex[i]]
        classCount[label] = classCount.get(label, 0) + 1

    # sorted() is stable even with reverse=True, so equal vote counts keep
    # their insertion (nearest-first) order.
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
def img2vector(filename):
    '''
    Load a 32*32 text image and flatten it into a 1*1024 row vector.

    The first 32 lines of the file are read and the first 32 characters of
    each line are converted with int().

    :param filename: path of the text file to read
    :return: array of shape (1, 1024); element i*32+j holds the value found
        at line i, column j of the file
    '''
    returnVet = zeros((1, 1024))
    # The context manager closes the handle even if a line is malformed;
    # the original never closed the file.
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVet[0, i * 32 + j] = int(lineStr[j])
    return returnVet
def handwritingClassTest(trainingDir=r'f:\python\trainingDigits',
                         testDir=r'f:\python\testDigits', k=3):
    '''
    Classify every file in a test directory and print the overall error rate.

    The true label of a file is parsed from its name: the text before the
    first '.' is taken, and the part of that before the first '_' is
    converted with int().

    :param trainingDir: directory whose files form the training set
    :param testDir: directory whose files are classified and scored
    :param k: number of neighbours handed to classify0
    :return: None; each prediction and the final error rate are printed
    '''
    # Build the training matrix (one flattened image per row) and its labels.
    hwLabels = []
    trainingFileList = os.listdir(trainingDir)
    m = len(trainingFileList)
    trainingMat = zeros((m, 1024))
    for i, fileNameStr in enumerate(trainingFileList):
        fileStr = fileNameStr.split('.')[0]
        hwLabels.append(int(fileStr.split('_')[0]))
        trainingMat[i, :] = img2vector(os.path.join(trainingDir, fileNameStr))

    testFileList = os.listdir(testDir)
    mTest = len(testFileList)
    if mTest == 0:
        # Without test files the error rate would be a division by zero.
        print("no test files found in %s" % testDir)
        return

    errorCount = 0.0
    for fileNameStr in testFileList:
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector(os.path.join(testDir, fileNameStr))
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, k)
        print("the classifier came back with:%d, the real answer is:%d" % (classifierResult,classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1.0
    print("the total error rate is %f" % (errorCount/float(mTest)))
def classifyHandwriting(trainingDir=r'f:\python\trainingDigits', k=3):
    '''
    Interactively classify digit image files named by the user.

    The user is prompted for a file name again and again; an empty answer
    ends the loop. Each named file is converted with img2vector, classified
    with classify0 against the training set, and the result is printed.

    The training set is read from disk the first time it is needed and is
    then reused for every later prompt, instead of being reloaded each time.

    :param trainingDir: directory whose files form the training set
    :param k: number of neighbours handed to classify0
    :return: None; each result is printed
    '''
    trainingMat = None
    hwLabels = []
    while True:
        filename = input("give your filename of the number:")
        if filename == '':
            break
        # Vector of the image to recognise.
        vector = img2vector(filename)

        # Feature matrix and target labels: loaded lazily, once.
        if trainingMat is None:
            trainingFileList = os.listdir(trainingDir)
            loadedLabels = []
            loadedMat = zeros((len(trainingFileList), 1024))
            for i, fileNameStr in enumerate(trainingFileList):
                fileStr = fileNameStr.split('.')[0]
                loadedLabels.append(int(fileStr.split('_')[0]))
                loadedMat[i, :] = img2vector(os.path.join(trainingDir, fileNameStr))
            # Publish only after the whole set has loaded successfully.
            trainingMat = loadedMat
            hwLabels = loadedLabels

        # Classify.
        result = classify0(vector, trainingMat, hwLabels, k)
        print("the number is: %d" % result)
classifyHandwriting 是我添加的函数,可以识别任意给定的文件;当然,训练集和测试集用的是《机器学习实战》的配套数据。