构造一个能识别数字 0 到 9 的基于 KNN 分类器的手写数字识别系统。
需要识别的数字以文本文件的形式存储，均为色彩和大小相同的黑白图像：宽和高均为 32 像素 × 32 像素。
在控制台输入 mspaint 打开画图工具，把画布大小调整为 32 × 32 像素；
保存时选择“单色位图”格式。
def getImageArr(path):
    """Load an image file and convert it to an inverted 0/1 integer matrix.

    The first channel of the array returned by ``cv2.imread`` is divided by
    255 and truncated to ``int32``, so only pixels whose channel value is
    255 become 1 and every other pixel becomes 0.  The matrix is then
    inverted (0 <-> 1): full-intensity pixels end up as 0 and all other
    pixels end up as 1.

    Args:
        path: Path of the image file to read.

    Returns:
        A 2-D ``int32`` array of zeros and ones with the image's height
        and width.

    Raises:
        OSError: If ``cv2.imread`` could not read the file.
    """
    imageArr = cv2.imread(path)
    if imageArr is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # rather than raising, which would otherwise surface as an opaque
        # "NoneType is not subscriptable" error on the next line.
        raise OSError("cannot read image file: {}".format(path))
    imageArr = (imageArr[:, :, 0] / 255).astype(np.int32)
    # Every value is exactly 0 or 1 at this point, so 1 - x swaps them.
    return 1 - imageArr
解析后得到由 0 和 1 组成的矩阵（形如 01010）。
解析题目给出的数据集：
def img2vector(filename):
    """Read a 32x32 text image of digit characters into a 1x1024 row vector.

    The first 32 characters of each of the first 32 lines of the file are
    converted with ``int`` and stored row by row, so the character at
    (row i, column j) lands at index ``32 * i + j``.

    Args:
        filename: Path of the text file to read.

    Returns:
        A float ``numpy`` array of shape ``(1, 1024)``.

    Raises:
        IndexError: If the file has fewer than 32 lines or a line has fewer
            than 32 characters.
        ValueError: If a character cannot be converted with ``int``.
    """
    returnVect = np.zeros((1, 1024))
    # The context manager closes the file even when parsing fails; the
    # previous version never closed the handle.
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect
实现 k 近邻（KNN）算法：
def classifyFun(inX, dataList, dataLabels, k):
    """Classify ``inX`` by majority vote among its k nearest training rows.

    Distances are Euclidean, computed between ``inX`` and every row of
    ``dataList``.  The labels of the ``k`` closest rows are counted and the
    label with the most votes is returned; on a tie, the tied label that was
    met first while walking the neighbours from nearest to farthest wins.
    The sorted ``(label, votes)`` list is also printed.

    Args:
        inX: Feature vector to classify; must broadcast against the rows of
            ``dataList``.
        dataList: 2-D array with one training sample per row.
        dataLabels: Sequence of labels, indexed like the rows of
            ``dataList``.
        k: Number of neighbours that vote.  Values larger than the number
            of training rows are clamped to that number.

    Returns:
        The winning label.

    Raises:
        ValueError: If ``k`` is less than 1 or ``dataList`` has no rows.
    """
    dataListLen = dataList.shape[0]
    if dataListLen < 1:
        raise ValueError("dataList must contain at least one sample")
    if k < 1:
        raise ValueError("k must be at least 1, got {}".format(k))
    # Asking for more neighbours than exist used to raise IndexError;
    # voting with every available sample is the natural interpretation.
    k = min(k, dataListLen)

    # Broadcasting subtracts inX from every row without materialising a
    # tiled copy of it.
    difMat = np.asarray(inX) - dataList
    difDist = (difMat ** 2).sum(axis=1) ** 0.5

    # Indices of the training rows ordered from nearest to farthest.
    nearestIdx = difDist.argsort()[:k]

    classCount = {}
    for idx in nearestIdx:
        getLabel = dataLabels[idx]
        classCount[getLabel] = classCount.get(getLabel, 0) + 1

    # sorted() is stable, so labels with equal votes keep insertion order.
    classifyAns = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    print(classifyAns)
    return classifyAns[0][0]
def ClassTrain(dataDir="trainingDigits"):
    """Load every training sample in ``dataDir`` into a matrix plus labels.

    Each directory entry is read with ``img2vector`` into one row of the
    matrix.  Its label is the integer found before the first underscore of
    the file name (after dropping everything from the first dot onwards),
    e.g. a file named ``5_12.txt`` gets label 5.

    Args:
        dataDir: Directory holding the training files.  Defaults to
            ``"trainingDigits"``, the location used previously.

    Returns:
        A tuple ``(trainingMat, digLabels)``: a float array with one
        1024-element row per file, and the list of integer labels in the
        same order.

    Raises:
        ValueError: If a file name does not start with an integer label.
    """
    trainingFileList = os.listdir(dataDir)
    trainingMat = np.zeros((len(trainingFileList), 32 * 32))
    digLabels = []
    for i, fileNameStr in enumerate(trainingFileList):
        fileStr = fileNameStr.split('.')[0]
        digLabels.append(int(fileStr.split('_')[0]))
        trainingMat[i, :] = img2vector(os.path.join(dataDir, fileNameStr))
    return trainingMat, digLabels
最后调用上述函数完成识别：
if __name__ == "__main__":
    # Convert the hand-drawn bitmap to the 0/1 matrix used for classification.
    imageIn = getImageArr("5.bmp")
    trainingMat, digLabels = ClassTrain()
    # Flatten the image matrix into a single feature vector and let the
    # 10 nearest neighbours vote.
    classifierResult = classifyFun(imageIn.flatten(), trainingMat, digLabels, 10)
    plt.imshow(imageIn)
    print(classifierResult)
    # imshow only draws onto the current figure; when run as a script the
    # window appears only once show() is called.  It is called last so the
    # result is printed before the (possibly blocking) window opens.
    plt.show()