Python 手写数字识别:用 KNN 分类算法实现手写数字识别

#!/usr/bin/python
#coding=utf-8
#########################################
# kNN: k Nearest Neighbors
#
# Parameters:
#   newInput: vector to classify (1 x N)
#   dataSet:  m known vectors, one per row (m x N)
#   labels:   labels of the known vectors (length m)
#   k:        number of neighbors that vote
#
# Output: the majority class among the k nearest neighbors
#########################################

import operator
import os
from collections import Counter

import numpy as np
from numpy import *

# KNN classification core routine.
def kNNClassify(newInput, dataSet, labels, k):
    """Classify ``newInput`` by majority vote of its ``k`` nearest neighbors.

    Args:
        newInput: the vector to classify, shape ``(N,)`` or ``(1, N)``.
        dataSet: the known vectors, one per row, shape ``(m, N)``.
        labels: sequence of ``m`` labels; ``labels[i]`` belongs to row ``i``.
        k: number of neighbors that vote. Values larger than ``m`` are
            clamped to ``m``.

    Returns:
        The label occurring most often among the ``k`` rows closest to
        ``newInput`` (Euclidean distance). On a vote tie, the label whose
        member ranked closest wins.

    Raises:
        ValueError: if ``k < 1`` or ``dataSet`` has no rows.
    """
    dataSet = np.asarray(dataSet)
    newInput = np.asarray(newInput)
    numSamples = dataSet.shape[0]  # number of rows = number of known samples
    if k < 1:
        raise ValueError('k must be at least 1, got %r' % (k,))
    if numSamples == 0:
        raise ValueError('dataSet contains no samples')

    # Step 1: squared Euclidean distance to every row. Broadcasting replaces
    # the explicit tile(); sqrt is skipped because it is monotonic and does
    # not change the neighbor ranking. The .sum() method is used on purpose:
    # `from numpy import *` at file level shadows the builtin sum/min/max.
    squaredDist = ((dataSet - newInput) ** 2).sum(axis=1)

    # Step 2: rank rows by distance. A stable sort resolves equal distances
    # to the lower row index, so results are reproducible.
    sortedDistIndices = np.argsort(squaredDist, kind='stable')

    # Steps 3-4: count the labels of the k nearest rows. Slicing past the end
    # simply returns all rows, which clamps k to numSamples.
    classCount = Counter(labels[i] for i in sortedDistIndices[:k])

    # Step 5: most frequent label. most_common() lists equal counts in
    # first-encountered order, i.e. the label seen at the nearest rank wins,
    # matching the original strict ">" scan.
    return classCount.most_common(1)[0][0]

# Convert a text-encoded digit image into a flat row vector.
def img2vector(filename, rows=32, cols=32):
    """Read a ``rows`` x ``cols`` text image into a single row vector.

    Each of the first ``rows`` lines of the file must contain at least
    ``cols`` single-digit characters; character ``col`` of line ``row``
    becomes element ``row * cols + col`` of the result.

    Args:
        filename: path of the text file to read.
        rows: number of lines to read (default 32).
        cols: number of characters to read from each line (default 32).

    Returns:
        A float array of shape ``(1, rows * cols)``.

    Raises:
        IndexError: if a line is shorter than ``cols`` characters.
        ValueError: if a character is not a decimal digit.
    """
    imgVector = np.zeros((1, rows * cols))
    # `with` guarantees the file is closed, even when a malformed line raises.
    with open(filename) as fileIn:
        for row in range(rows):
            lineStr = fileIn.readline()
            for col in range(cols):
                # Stride by `cols` (not a hard-coded 32) so non-default image
                # sizes are laid out correctly.
                imgVector[0, row * cols + col] = int(lineStr[col])
    return imgVector

# Load the training and test data sets.
def _loadDigitsDir(dirPath):
    """Load every digit file in ``dirPath`` into a feature matrix and labels.

    File names must look like ``<label>_<index>.txt`` (e.g. ``1_18.txt``);
    the integer before the first ``_`` is taken as the label.

    Returns:
        ``(x, y)``: ``x`` has one 1024-element row per file (the 32x32 image
        read by ``img2vector``) and ``y`` is the list of integer labels in
        the same order.
    """
    # Sort so the sample order does not depend on the file system.
    fileList = sorted(os.listdir(dirPath))
    x = np.zeros((len(fileList), 1024))
    y = []
    for i, filename in enumerate(fileList):
        x[i, :] = img2vector(os.path.join(dirPath, filename))
        # e.g. "1_18.txt" -> 1
        y.append(int(filename.split('_')[0]))
    return x, y


def loadDataSet(dataSetDir='E:/KNNCase/digits/'):
    """Load the handwritten-digit training and test sets.

    ``dataSetDir`` must contain ``trainingDigits`` and ``testDigits``
    sub-directories whose files are named ``<label>_<index>.txt``.

    Args:
        dataSetDir: root directory of the data set; the default is the
            location previously hard-coded here. A trailing slash is optional.

    Returns:
        ``(train_x, train_y, test_x, test_y)``: feature matrices with one
        1024-element row per file and the matching lists of integer labels.
    """
    ## step 1: read the training set
    print("---Getting training set...")
    train_x, train_y = _loadDigitsDir(os.path.join(dataSetDir, 'trainingDigits'))

    ## step 2: read the test set
    print("---Getting testing set...")
    test_x, test_y = _loadDigitsDir(os.path.join(dataSetDir, 'testDigits'))

    return train_x, train_y, test_x, test_y

# Main handwriting-recognition flow.
def testHandWritingClass(k=3):
    """Classify every test digit with kNN and print the resulting accuracy.

    Args:
        k: number of neighbors passed to ``kNNClassify`` (default 3, the
            value previously hard-coded).

    Returns:
        The fraction of test samples classified correctly, in ``[0, 1]``.

    Raises:
        ValueError: if the test set is empty.
    """
    ## step 1: load data
    print("step 1: load data...")
    train_x, train_y, test_x, test_y = loadDataSet()

    ## step 2: kNN is a lazy learner; there is no training phase, the stored
    # training set itself is the model.
    print("step 2: training...")

    ## step 3: classify each test sample and count the hits
    print("step 3: testing...")
    numTestSamples = test_x.shape[0]
    if numTestSamples == 0:
        raise ValueError('no test samples were loaded')
    matchCount = 0
    for sample, expected in zip(test_x, test_y):
        if kNNClassify(sample, train_x, train_y, k) == expected:
            matchCount += 1
    accuracy = float(matchCount) / numTestSamples

    ## step 4: report
    print("step 4: show the result...")
    print('The classify accuracy is: %.2f%%' % (accuracy * 100))
    return accuracy

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值