机器学习算法一——k-近邻算法(3)
使用k-近邻算法的手写识别系统
需要识别的数字已经使用图形处理软件,处理成具有相同的色彩和大小:宽高是32像素x32像素的黑白图像。尽管采用文本格式存储图像不能有效地利用内存空间,但是为了方便理解,我们还是将图像转换为文本格式。
1、准备数据:将图像转换为测试向量
目录trainingDigits中包含了大约2000个例子(每个数字大约200个样本);目录testDigits中包含了大约900个测试数据。
我们将一个32x32的二进制图像矩阵转换为1x1024的向量,这样前两节使用的分类器就可以处理数字图像信息了。
##将图像转换为向量
def img2vector(filename):
returnvect = zeros((1,1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnvect[0,32*i+j] = int(lineStr[j])
return returnvect
2、测试算法:使用k-近邻算法识别手写数字
from os import listdir
必须确保将这句代码写入文件的起始部分
##手写数字识别系统的测试代码
def handwritingClassTest():
hwLabels = []
trainingFileList = listdir('trainingDigits')
m = len(trainingFileList)
trainingMat = zeros((m,1024))#m为训练集中的数字个数,每个数字由1024个元素组成
#从文件名解析分类数字
for i in range(m):
fileNameStr = trainingFileList[i]
fileStr = fileNameStr.split('.')[0] #取.txt文件的文件名
classNumStr = int(fileStr.split('_')[0]) #取某个数字,去除编号
hwLabels.append(classNumStr) #classNumstr为对应标签
trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
testFileList = listdir('testDigits')
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUnderTest,trainingMat, hwLabels,3)
print("the classifier came back with: %d, the real answer is: %d" %(classifierResult,classNumStr))
if (classifierResult != classNumStr):
errorCount += 1.0
print("\nthe total number of errors is: %d" % errorCount)
print("\nthe total error rate is: %f" % (errorCount/float(mTest)))
执行结果:
>>>the classifier came back with: 0, the real answer is: 0
the classifier came back with: 0, the real answer is: 0
the classifier came back with: 0, the real answer is: 0
the classifier came back with: 0, the real answer is: 0
……
the classifier came back with: 6, the real answer is: 6
the classifier came back with: 6, the real answer is: 6
the classifier came back with: 7, the real answer is: 7
the classifier came back with: 7, the real answer is: 7
……
the classifier came back with: 9, the real answer is: 9
the classifier came back with: 9, the real answer is: 9
the total number of errors is: 11
the total error rate is: 0.011628
可见,错误率为1.2%。
改变变量k的值、修改函数handwritingClassTest随机选取训练样本、改变训练样本的数目,都会对k-近邻算法的错误率产生影响。
缺点:实际使用这个算法时,算法的执行效率并不高。因为算法需要为每个测试向量做2000次距离计算,每个距离计算包括了1024个维度浮点运算,总计要执行900次。此外,还需为测试向量准备2MB存储空间。
k决策树就是k-近邻算法的优化版,可节省大量的计算开销。
(1)os.listdir()方法
功能:用于返回指定的文件夹包含的文件或文件夹的名字的列表。这个列表以字母顺序。
语法:os.listdir(path)
参数:需要列出的目录路径
返回值:返回指定路径下的文件和文件夹列表。
k-近邻算法总结
k-近邻算法是分类数据最简单有效的算法。
k-近邻算法是基于实例的学习,使用此算法时我们必须有接近实际数据的训练样本数据。
k-近邻算法必须保存全部数据集,如果训练数据集很大,必须使用大量的存储空间。
此外,由于必须对数据集中的每个数据计算距离值,实际使用时可能非常耗时。
k-近邻算法的另一个缺陷:无法给出任何数据的基础结构信息,因此我们也无法知晓平均实例样本和典型实例样本具有什么特征。
下一章使用概率测量方法处理分类问题。