#-*- coding: UTF-8 -*-
from numpy import *
import operator
from os import listdir
def classify0(inX,dataSet,labels,k):
dataSetSize = dataSet.shape[0] #获得训练集的第二维长度即行数,shape[1]为获得第一维长度即列数
diffMat = tile(inX,(dataSetSize,1)) - dataSet #将输入的待检测向量行数扩展为已有行数的datasize倍,列数不变,此时向量扩展所得矩阵的行列数与训练集矩阵的行列数相等,可以做差
sqDiffMat = diffMat ** 2 #做差所得矩阵每个元素取平方
sqDistances = sqDiffMat.sum(axis = 1) #将矩阵按行相加,得到每行的元素的平方和(axis=0为按列相加)
distances = sqDistances ** 0.5 #开平方,得到待测向量集与训练集各行元素的距离
sortedDistIndicies = distances.argsort() #argsort()函数将距离按从小到大的顺序排序,并返回排序后元素在原来未排序列表中的索引值(未排序元素行与label中的元素一一对应)
classCount = {} #定义一个新的存储类别数目的空字典
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]#通过sortedDistIndices中第i个元素的索引,找对应labels中的类别(如第0个元素的索引值为120,其在labels中的类别为1,说明该向量与1最接近)
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 #将得到的某类别的个数存入字典,get函数作用是:如果该类别存在,则取出已有数目,不存在则置0;然后与1作和
sortedClassCount = sorted(classCount.iteritems(),#将上面得到的字典元素进行排序,key=operator.itemgetter(1)将字典中的元素按值排序,即对于{A:1,B:2},按"1","2"排序,
key = operator.itemgetter(1),reverse = True ) #若itemgetter(0),则按"A","B"排序,reverse=True:按降序排序
return sortedClassCount[0][0] #排序后得到的新字典中,类别数最大的类与该类对于的值在最前面,第一个0意思是返回该字典中第一个键值,第二个0为返回该键值的键,即类别
def img2vector(filename):
returnVec = zeros((1,1024))#初始化返回向量,即构造一个1行1024列的矩阵
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVec[0,32 * i + j] = int(lineStr[j])#将读取到的文件中1024个数值存入到returnVec中
return returnVec
def handwritingClassTet():
hwLabels = []#定义空列表,以存放类别
trainingFileList = listdir('trainingDigits')
m = len(trainingFileList)
trainingMat = zeros((m,1024))#初始化训练矩阵,得到一个m行1024列的0矩阵,m为训练集文件数
for i in range(m):
fileNameStr = trainingFileList[i]
fileStr = fileNameStr.split('.')[0]#按.切分文件名,并取第一个切片,如9_24.txt,切分后取出9_24
classNumStr = int(fileStr.split('_')[0])#对上面的切片再切分,取第一个切片,如9_24,切分后取出9
hwLabels.append(classNumStr)#放入类别列表
trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)#调用图片向量转换函数,将其转换成1行1024列的向量
testFileList = listdir('testDigits')
errorCount = 0.0 #初始化误差值
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
vectorUndertest = img2vector('testDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUndertest,trainingMat,hwLabels,3)#调用训练算法
print "the classifier came back with: %d, the real answer is: %d" % (classifierResult,classNumStr)
if (classifierResult != classNumStr): errorCount += 1.0
print "\nthe total number of errors is %d" % errorCount
print "\nthe total error rate is: %f" % (errorCount / float(mTest))
print handwritingClassTet()
欢迎指正