1 准备数据:将图象转化为测试向量
为了使用前面使用过的分类器,将图象格式化为一个向量:将3232的二进制图像矩阵转化为11024的向量
def img2vector(filename):
returnVect = zeros((1, 1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0, 32 * i + j] = int(lineStr[j])
return returnVect
2 测试算法:使用k-近邻算法识别手写数组
根据:
https://blog.csdn.net/weixin_43340821/article/details/122043911
已能将数据处理成分类器可以世界的格式。因此此处只需将数据输入到分类器,检测分类器的执行效果
编写函数测试分类器:
通过from os import listdir写入文件的起始部分,从os模块导入函数listdir,可以列出给定目录的文件名(返回指定的文件夹包含的文件或文件夹的名字的列表)
def handwritingClassTest():
hwLabels = []
trainingFileList = listdir('trainingDigits')
m = len(trainingFileList)
traingMat = zeros((m, 1024))
for i in range(m):
fileNameStr = trainingFileList[i] # 获得每个文件的文件名
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0]) # 从文件名解析出分类的数字(通过观察文件名结构得出)
hwLabels.append(classNumStr)
traingMat[i, :] = img2vector('trainingDigits/%s' % fileNameStr)
testFileList = listdir('testDigits')
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUnderTest, traingMat, hwLabels, 3)
print("the classifier came back with: %d,the real answer is %d" % (classifierResult, classNumStr))
if (classifierResult != classNumStr):
errorCount += 1.0
print("the total number of errors is: %d" % errorCount)
print("the total error rate is: %f" % (errorCount / float(mTest)))
kNN.handwritingClassTest()
完整代码
from numpy import *
from os import listdir
import numpy as np
import operator
def classify0(inX, dataSet, labels, k): # inX为用于分类的输入向量,dataSet为输入的训练样本集,labels为标签向量,k为选择最近邻居的数目
dataSetSize = dataSet.shape[0] # 获取行数
diffMat = tile(inX, (dataSetSize, 1)) - dataSet
# tile(inx, (datasetsize, 1)):将矩阵inx纵向复制datasetsize份(成了datasetsize行),再横向复制一份
# title(对象矩阵,([m,] k )):title函数是个复制函数,作用是将对象矩阵作为一个单元进行横向和纵向复制,形成一个m*k的矩阵。注意的是,当()只有一个数值时,只会横向复制。
# 此处是将变量集复制为和训练集结构一样的矩阵,再和训练集进行矩阵计算,得到每个轴距。
sqDiffMat = diffMat ** 2
sqDistances = sqDiffMat.sum(axis=1) # 将矩阵的各行相加
# sum(对象,axis=0/1)或者对象.sum(axis=0/1):axis=0代表对象同列元素相加;axis=1 同行的元素相加,每行返回一个值;不带axis则表示所有元素相加
distances = sqDistances ** 0.5
sortedDistIndicies = distances.argsort() # distances.argsort():对distances进行从小到大排序,返回对应的索引。如排序后得到[3,2,0,1],最小的是排序对象的索引号为3的元素。
classCount = {} # 定义空字典
for i in range(k): # 选择距离最小的k个点
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
# classcount[voteilabel]:统计classcount中每个标签出现的次数
# get(voteilabel,0):原型为get(key,默认值),作用是获取key对应的值,如果不存在key,则新增key,值为默认值
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
# 字典items方法:将字典中所有项以列表方式返回;iteritems方法作用:与items方法相比作用大致相同,只是它的返回值不是列表,而是一个迭代器
# 按照每个key的值对每对keyvalue进行降序排序
# sorted(对象,排序值,reverse=True/False) sort 是应用在 list 上的方法,sorted 可以对所有可迭代的对象进行排序操作。
# ist.sort(*, key=None, reverse=False) 此方法会对列表进行原地排序,默认排序为升序。
# key 指定带有一个参数的函数,用于从每个列表元素中提取比较键 (例如 key=str.lower),可以是匿名函数,也可是自定义的函数。 key函数排序的过程中,对应于列表中每一项的键会被计算一次。 默认值 None 表示直接对列表项排序而不计算一个单独的键值。
# operator模块提供的itemgetter函数用于获取对象的哪些维的数据
return sortedClassCount[0][0]
def img2vector(filename):
returnVect = zeros((1, 1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0, 32 * i + j] = int(lineStr[j])
return returnVect
def handwritingClassTest():
hwLabels = []
trainingFileList = listdir('trainingDigits')
m = len(trainingFileList)
traingMat = zeros((m, 1024))
for i in range(m):
fileNameStr = trainingFileList[i] # 获得每个文件的文件名
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0]) # 从文件名解析出分类的数字(通过观察文件名结构得出)
hwLabels.append(classNumStr)
traingMat[i, :] = img2vector('trainingDigits/%s' % fileNameStr)
testFileList = listdir('testDigits')
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUnderTest, traingMat, hwLabels, 3)
print("the classifier came back with: %d,the real answer is %d" % (classifierResult, classNumStr))
if (classifierResult != classNumStr):
errorCount += 1.0
print("the total number of errors is: %d" % errorCount)
print("the total error rate is: %f" % (errorCount / float(mTest)))