#!/usr/bin/python
#coding=utf-8
#########################################
#kNN: k Nearest Neighbors

#Parameters: inX: vector to compare to the existing data set (1xN)
#dataSet: data set of M known vectors, one per row (MxN)
#labels: data set labels (1xM vector)
#k: number of neighbors to use for comparison

#Output: the majority class among the k nearest neighbors
#########################################

import operator
import os

from numpy import *


def kNNClassify(newInput, dataSet, labels, k):
    """Classify one sample by majority vote of its k nearest neighbours.

    Args:
        newInput: query vector, shape (N,) or (1, N).
        dataSet: known samples, shape (M, N) -- one sample per row.
        labels: sequence of M class labels; labels[i] belongs to dataSet[i].
        k: number of neighbours that vote; values above M are clamped to M.

    Returns:
        The label with the most votes. Ties are broken in favour of the
        label whose first vote came from the nearer neighbour, so the result
        is deterministic (plain dict iteration order is not, on Python 2).

    Raises:
        ValueError: if k < 1 or dataSet has no rows.
    """
    dataSet = asarray(dataSet)
    numSamples = dataSet.shape[0]  # number of rows = number of samples
    if k < 1:
        raise ValueError("k must be >= 1, got %r" % (k,))
    if numSamples == 0:
        raise ValueError("dataSet must contain at least one sample")
    if k > numSamples:
        k = numSamples

    # step 1: Euclidean distance to every row. Broadcasting newInput against
    # dataSet yields the same (M, N) difference as tile() without the copy.
    diff = asarray(newInput) - dataSet
    distance = (diff ** 2).sum(axis=1) ** 0.5

    # step 2: indices ordered by distance; mergesort is stable, so equally
    # distant samples keep their dataSet order.
    sortedDistIndices = argsort(distance, kind='mergesort')

    # steps 3-4: tally the labels of the k nearest samples, remembering the
    # order in which each label first appears (nearest first).
    classCount = {}
    labelOrder = []
    for rank in range(k):
        voteLabel = labels[sortedDistIndices[rank]]
        if voteLabel not in classCount:
            classCount[voteLabel] = 0
            labelOrder.append(voteLabel)
        classCount[voteLabel] += 1

    # step 5: most votes wins; strict '>' keeps the earlier (nearer) label
    # when counts tie.
    bestLabel, bestCount = None, 0
    for label in labelOrder:
        if classCount[label] > bestCount:
            bestLabel, bestCount = label, classCount[label]
    return bestLabel


def img2vector(filename, rows=32, cols=32):
    """Read a text-encoded image file and flatten it into a row vector.

    The file must contain at least ``rows`` lines, each starting with at
    least ``cols`` digit characters. Pixel (r, c) is stored at column
    ``r * cols + c`` of the result.

    Args:
        filename: path of the text image file.
        rows: number of lines to read (default 32, as before).
        cols: number of characters to read per line (default 32, as before).

    Returns:
        A float ndarray of shape (1, rows * cols).
    """
    imgVector = zeros((1, rows * cols))
    # 'with' guarantees the handle is closed even if a line fails to parse.
    with open(filename) as fileIn:
        for row in range(rows):
            lineStr = fileIn.readline()
            for col in range(cols):
                # Use cols (not a hard-coded 32) so custom sizes index correctly.
                imgVector[0, row * cols + col] = int(lineStr[col])
    return imgVector


def _loadDigitsDir(dirPath):
    """Load every digit file in dirPath; return (samples, labels).

    samples is a (numFiles, 1024) float array holding one flattened image
    (img2vector's default 32x32) per row; labels[i] is the integer class
    parsed from the i-th file name.
    """
    # os.listdir order is arbitrary; sort so runs are reproducible.
    fileNames = sorted(os.listdir(dirPath))
    samples = zeros((len(fileNames), 1024))
    labels = []
    for i, fileName in enumerate(fileNames):
        samples[i, :] = img2vector(os.path.join(dirPath, fileName))
        # The class is the file-name prefix, e.g. "1_18.txt" -> 1.
        labels.append(int(fileName.split('_')[0]))
    return samples, labels


def loadDataSet(dataSetDir='E:/KNNCase/digits/'):
    """Load the training and test digit sets.

    Args:
        dataSetDir: root directory containing the 'trainingDigits' and
            'testDigits' sub-directories (defaults to the original
            hard-coded location).

    Returns:
        (train_x, train_y, test_x, test_y): sample matrices of shape
        (n, 1024) and the matching lists of integer labels.
    """
    ## step 1: training set
    print("---Getting training set...")
    train_x, train_y = _loadDigitsDir(os.path.join(dataSetDir, 'trainingDigits'))

    ## step 2: test set
    print("---Getting testing set...")
    test_x, test_y = _loadDigitsDir(os.path.join(dataSetDir, 'testDigits'))

    return train_x, train_y, test_x, test_y


def testHandWritingClass(k=3):
    """Run the handwritten-digit kNN experiment end to end.

    Loads the data via loadDataSet(), classifies every test sample with
    kNNClassify() and prints the resulting accuracy.

    Args:
        k: number of neighbours used by kNNClassify (default 3, as before).

    Returns:
        The fraction of correctly classified test samples (0.0 when the test
        set is empty, instead of dividing by zero).
    """
    ## step 1: load data
    print("step 1: load data...")
    train_x, train_y, test_x, test_y = loadDataSet()

    ## step 2: kNN has no training phase; the stored samples are used
    ## directly at prediction time.
    print("step 2: training...")

    ## step 3: test
    print("step 3: testing...")
    numTestSamples = test_x.shape[0]
    matchCount = 0
    for i in range(numTestSamples):
        predict = kNNClassify(test_x[i], train_x, train_y, k)
        if predict == test_y[i]:
            matchCount += 1
    # Guard against an empty test directory rather than raising ZeroDivisionError.
    if numTestSamples:
        accuracy = float(matchCount) / numTestSamples
    else:
        accuracy = 0.0

    ## step 4: report
    print("step 4: show the result...")
    print('The classify accuracy is: %.2f%%' % (accuracy * 100))
    return accuracy