背景
对于需要识别的数字使用图形处理软件,处理成具有相同的色彩和大小:
宽高是32像素x32像素。尽管采用本文格式存储图像不能有效地利用内存空间
但是为了方便理解,将图片转换为文本格式。
数字图片是32x32的二进制图像,为了方便计算,
将32x32的二进制图像转换为1x1024的向量(32*32=1024)。
对于sklearn的KNeighborsClassifier输入可以是矩阵,不用一定转换为向量,
不过为了跟写的k-近邻算法分类器对应上,这里做了向量化处理。
然后构建kNN分类器,利用分类器做预测。
1. KNN算法,分类器代码
import numpy as np
import operator
from os import listdir
from sklearn.neighbors import KNeighborsClassifier as kNN
"""
函数说明:将32x32的二进制图像转换为1x1024向量。
Parameters:
filename - 文件名
Returns:
returnVect - 返回的二进制图像的1x1024向量
"""
def img2vector(filename):
#创建1x1024零向量
returnVect = np.zeros((1, 1024))
#打开文件
fr = open(filename)
#按行读取
for i in range(32):
#读一行数据
lineStr = fr.readline()
#每一行的前32个元素依次添加到returnVect中
for j in range(32):
returnVect[0, 32*i+j] = int(lineStr[j])
#返回转换后的1x1024向量
return returnVect
"""
函数说明:手写数字分类测试
Parameters:
无
Returns:
无
"""
def handwritingClassTest():
#测试集的Labels
hwLabels = []
#返回trainingDigits目录下的文件名
trainingFileList = listdir('trainingDigits')
#返回文件夹下文件的个数
m = len(trainingFileList)
#初始化训练的Mat矩阵,测试集
trainingMat = np.zeros((m, 1024))
#从文件名中解析出训练集的类别
for i in range(m):
#获得文件的名字
fileNameStr = trainingFileList[i]
#获得分类的数字
classNumber = int(fileNameStr.split('_')[0])
#将获得的类别添加到hwLabels中
hwLabels.append(classNumber)
#将每一个文件的1x1024数据存储到trainingMat矩阵中
trainingMat[i,:] = img2vector('trainingDigits/%s' % (fileNameStr))
#构建kNN分类器
neigh = kNN(n_neighbors = 3, algorithm = 'auto')
#拟合模型, trainingMat为训练矩阵,hwLabels为对应的标签
neigh.fit(trainingMat, hwLabels)
#返回testDigits目录下的文件列表
testFileList = listdir('testDigits')
#错误检测计数
errorCount = 0.0
#测试数据的数量
mTest = len(testFileList)
#从文件中解析出测试集的类别并进行分类测试
for i in range(mTest):
#获得文件的名字
fileNameStr = testFileList[i]
#获得分类的数字
classNumber = int(fileNameStr.split('_')[0])
#获得测试集的1x1024向量,用于训练
vectorUnderTest = img2vector('testDigits/%s' % (fileNameStr))
#获得预测结果
# classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
classifierResult = neigh.predict(vectorUnderTest)
print("分类返回结果为%d\t真实结果为%d" % (classifierResult, classNumber))
if(classifierResult != classNumber):
errorCount += 1.0
print("总共错了%d个数据\n错误率为%f%%" % (errorCount, errorCount/mTest * 100))
"""
函数说明:main函数
Parameters:
无
Returns:
无
"""
if __name__ == '__main__':
handwritingClassTest()
运行结果:
分类返回结果为1 真实结果为1
...
...
...
...
...
分类返回结果为9 真实结果为9
分类返回结果为9 真实结果为9
分类返回结果为9 真实结果为9
分类返回结果为9 真实结果为9
分类返回结果为9 真实结果为9
分类返回结果为9 真实结果为9
分类返回结果为9 真实结果为9
分类返回结果为9 真实结果为9
分类返回结果为9 真实结果为9
分类返回结果为9 真实结果为9
分类返回结果为9 真实结果为9
分类返回结果为9 真实结果为9
总共错了12个数据
错误率为1.268499%
参考博客:https://blog.csdn.net/c406495762/article/details/75172850