1 数据
数据分为两个目录:train_npy(训练数据)和 test_npy(测试数据)。
每个样本为 .npy 格式的 0/1 二值图像,尺寸 32×32(共 1024 个像素)。
数字包含 0~9,文件名开头(第一个下划线之前)的数字为标签。数据不多,可以自己再补充一些。
链接:https://pan.baidu.com/s/1tB-B5v1eQbYUwh1808bZFA
提取码:lucy
2 代码
from numpy import *
import numpy as np
import cv2
from os import listdir
import operator
def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its ``k`` nearest samples.

    Distances are Euclidean and are computed in floating point, so
    unsigned-integer inputs (e.g. ``uint8`` images) cannot wrap around.

    Args:
        inX: Query sample. Any array-like; it is flattened to 1-D, so a
            2-D image and its flattened vector are treated the same.
        dataSet: Training samples, one per entry along the first axis.
            Each sample is flattened to a row vector.
        labels: Sequence of labels; ``labels[i]`` belongs to ``dataSet[i]``.
        k: Number of nearest neighbours that vote. Values larger than the
            number of training samples are clamped to that number.

    Returns:
        The label with the most votes among the ``k`` nearest samples.
        On a tie, the tied label that was encountered first (i.e. whose
        first voter is nearest) wins.

    Raises:
        ValueError: If ``dataSet`` is empty, ``k`` is smaller than 1, or
            there are fewer labels than training samples.
    """
    samples = np.asarray(dataSet, dtype=float)
    if samples.ndim == 0 or samples.shape[0] == 0:
        raise ValueError("dataSet must contain at least one sample")
    n_samples = samples.shape[0]
    if len(labels) < n_samples:
        raise ValueError("labels must provide one entry per training sample")
    if k < 1:
        raise ValueError("k must be at least 1")

    # One row per training sample, whatever the per-sample shape was.
    samples = samples.reshape(n_samples, -1)
    query = np.asarray(inX, dtype=float).ravel()

    # Broadcasting subtracts the query from every row without building
    # an explicit tiled copy of it.
    diff = query - samples
    distances = (diff ** 2).sum(axis=1) ** 0.5
    nearest = distances.argsort()

    # Never ask for more neighbours than there are training samples.
    k = min(k, n_samples)

    votes = {}
    for idx in nearest[:k]:
        label = labels[idx]
        votes[label] = votes.get(label, 0) + 1

    # max() returns the first maximal key in insertion order, which matches
    # a stable descending sort of the vote counts.
    return max(votes, key=votes.get)
def _load_labeled_samples(directory):
    """Load every ``.npy`` file in ``directory`` as a flattened 1024-vector.

    The label of a sample is the integer found before the first ``_`` in
    the file name (after dropping everything from the first ``.``).

    Returns:
        A ``(samples, labels)`` pair: a float array with one row of 1024
        values per file, and the matching list of integer labels.
    """
    from os.path import join

    # Skip stray non-sample files and sort so that the row order (and thus
    # tie-breaking between equally distant neighbours) is reproducible.
    file_names = sorted(
        name for name in listdir(directory) if name.endswith('.npy')
    )
    samples = np.zeros((len(file_names), 1024))
    labels = []
    for row, file_name in enumerate(file_names):
        # ravel() accepts both a 32x32 image and an already flat vector.
        samples[row, :] = np.load(join(directory, file_name)).ravel()
        labels.append(int(file_name.split('.')[0].split('_')[0]))
    return samples, labels


def handwritingClassTest(train_dir='train_npy', test_dir='test_npy', k=3):
    """Evaluate the kNN digit classifier on a directory of test samples.

    Every test sample is classified with ``classify0`` against the training
    samples; the prediction and the true label are printed per sample,
    followed by the overall accuracy as a percentage.

    Args:
        train_dir: Directory containing the training ``.npy`` files.
        test_dir: Directory containing the test ``.npy`` files.
        k: Number of neighbours passed to ``classify0``.

    Raises:
        ValueError: If either directory contains no ``.npy`` samples.
    """
    trainMat, Labels = _load_labeled_samples(train_dir)
    if not Labels:
        raise ValueError("no .npy training samples found in {!r}".format(train_dir))

    testMat, testLabels = _load_labeled_samples(test_dir)
    test_num = len(testLabels)
    if test_num == 0:
        # Without this guard the accuracy computation divides by zero.
        raise ValueError("no .npy test samples found in {!r}".format(test_dir))

    right_count = 0
    for cur_testMat, num in zip(testMat, testLabels):
        classifierresult = classify0(cur_testMat, trainMat, Labels, k)
        print("get:%d,real:%d" % (classifierresult, num))
        if classifierresult == num:
            right_count += 1
    print("正确率:{}%".format(np.round(right_count * 100 / test_num, 2)))
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    handwritingClassTest()