'''
Created on Nov 06, 2017
kNN: k Nearest Neighbors

Input:      inX: vector to compare to existing dataset (1xN)
            dataSet: size m data set of known vectors (MxN, one sample per row)
            labels: data set labels (1xM vector)
            k: number of neighbors to use for comparison (should be an odd number)

Output:     the most popular class label

@author: Liu Chuanfeng
'''
import operator
import os
from collections import Counter
from os import listdir

import matplotlib.pyplot as plt
import numpy as np
def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its k nearest neighbours.

    Args:
        inX: query vector, shape (N,) or (1, N).
        dataSet: (M, N) array of known samples, one sample per row.
        labels: sequence of M class labels aligned with the rows of dataSet.
        k: number of neighbours that vote. Values larger than M are clamped
            to M, so every sample votes instead of raising IndexError.

    Returns:
        The label with the most votes. On a tie, the label whose first vote
        came from the nearer neighbour wins.

    Raises:
        ValueError: if k < 1 or dataSet has no rows.
    """
    if k < 1:
        raise ValueError('k must be at least 1, got %r' % (k,))
    dataSet = np.asarray(dataSet)
    if dataSet.shape[0] == 0:
        raise ValueError('dataSet must contain at least one sample')
    # Broadcasting subtracts inX from every row without building the M x N
    # copy that np.tile would allocate.
    diffMat = np.asarray(inX) - dataSet
    distances = np.sqrt((diffMat ** 2).sum(axis=1))
    # A stable sort keeps equally distant samples in dataset order, so the
    # neighbour set (and therefore the result) is deterministic.
    sortedDistIndicies = distances.argsort(kind='stable')
    k = min(k, len(sortedDistIndicies))
    # Counter preserves first-insertion order and most_common breaks ties by
    # that order, i.e. in favour of the label seen at the nearer neighbour.
    votes = Counter(labels[idx] for idx in sortedDistIndicies[:k])
    return votes.most_common(1)[0][0]
# Data preprocessing: convert the records in a text file into a matrix.
def file2matrix(filename, numFeatures=3):
    """Load a tab-separated dataset file into a feature matrix and label list.

    Each non-blank line holds ``numFeatures`` leading numeric feature columns.
    Its last column is an integer class label.

    Args:
        filename: path to the tab-separated text file.
        numFeatures: number of leading columns to read as features
            (defaults to 3).

    Returns:
        (returnMat, classLabelVector): an (m, numFeatures) float array with
        one row per non-blank line, and the list of m integer labels.
    """
    featureRows = []
    classLabelVector = []
    # 'with' guarantees the file is closed even if parsing raises.
    with open(filename) as fr:
        for line in fr:
            line = line.strip()
            if not line:
                # Skip blank lines (e.g. a trailing newline at EOF) instead of
                # failing to convert an empty field.
                continue
            listFromLine = line.split('\t')
            featureRows.append([float(v) for v in listFromLine[0:numFeatures]])
            classLabelVector.append(int(listFromLine[-1]))
    returnMat = np.zeros((len(featureRows), numFeatures))
    for index, row in enumerate(featureRows):
        returnMat[index, :] = row
    return returnMat, classLabelVector
# Data normalisation: the columns span very different value ranges, which
# would give them unequal influence on the distance; normalise so that every
# feature carries the same weight.
def autoNorm(dataSet):
    """Min-max normalise every column of ``dataSet`` into [0, 1].

    Args:
        dataSet: (m, n) numeric array, one sample per row.

    Returns:
        (normDataSet, ranges, minVals): the normalised (m, n) array, the
        per-column divisor (max - min, with 0 replaced by 1), and the
        per-column minimum. New samples are normalised the same way with
        ``(x - minVals) / ranges``.
    """
    dataSet = np.asarray(dataSet)
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    # A constant column has zero range and would divide to NaN/inf. Use 1
    # instead, so that column normalises to all zeros and callers reusing
    # `ranges` stay finite too.
    ranges = np.where(ranges == 0, 1, ranges)
    # Broadcasting applies the per-column min/range to every row.
    normDataSet = (dataSet - minVals) / ranges
    return normDataSet, ranges, minVals
# Test harness for the dating-site classifier.
def datingClassTest(filename='datingTestSet2.txt', hoRatio=0.10, k=3):
    """Evaluate the kNN classifier on the dating dataset with a hold-out split.

    After normalisation, the first ``hoRatio`` fraction of rows is classified
    against the remaining rows. Every prediction and the overall error rate
    are printed.

    Args:
        filename: tab-separated dataset read by file2matrix.
        hoRatio: fraction of rows held out as the test set.
        k: number of neighbours passed to classify0.

    Returns:
        The error rate as a fraction in [0, 1].

    Raises:
        ValueError: if hoRatio leaves no rows to test.
    """
    datingDataMat, datingLabels = file2matrix(filename)
    normMat, _, _ = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m * hoRatio)
    if numTestVecs == 0:
        raise ValueError('hoRatio=%r selects no test rows out of %d' % (hoRatio, m))
    # The training slice is the same on every iteration; build it once.
    trainMat = normMat[numTestVecs:m, :]
    trainLabels = datingLabels[numTestVecs:m]
    errorCount = 0.0
    for i in range(numTestVecs):
        classifyResult = classify0(normMat[i, :], trainMat, trainLabels, k)
        print('the classifier came back with: %d, the real answer is: %d' % (classifyResult, datingLabels[i]))
        if classifyResult != datingLabels[i]:
            errorCount += 1.0
    errorRate = errorCount / float(numTestVecs)
    print('the total error rate is: %.1f%%' % (errorRate * 100))
    return errorRate
# Prediction function for the dating site.
def classifyPerson(filename='datingTestSet2.txt', k=3):
    """Interactively predict how much the user would like a candidate.

    Reads three feature values from stdin and normalises them with the
    training data's min/range statistics. The predicted category is printed.

    Args:
        filename: tab-separated dataset read by file2matrix.
        k: number of neighbours passed to classify0.
    """
    resultList = ['not at all', 'in small doses', 'in large doses']
    percentGames = float(input("percentage of time spent playing video games?"))
    ffMiles = float(input("frequent flier miles earned per year?"))
    iceCream = float(input("liters of ice cream consumed per year?"))
    datingDataMat, datingLabels = file2matrix(filename)
    normMat, ranges, minVals = autoNorm(datingDataMat)
    # NOTE(review): this order (miles, game %, ice cream) must match the
    # column order of the data file -- confirm against the file.
    inArr = np.array([ffMiles, percentGames, iceCream])
    # Normalise the query exactly like the training rows.
    classifyResult = classify0((inArr - minVals) / ranges, normMat, datingLabels, k)
    # The -1 maps file labels onto resultList's 0-based indices
    # (assumes the labels are 1..3).
    print("You will probably like this person:", resultList[classifyResult - 1])
85
# Handwritten digit recognition system ======================================
# Preprocessing: each input image is a 32*32 text grid, flattened to 1*1024.
def img2vector(filename, size=32):
    """Flatten a text-encoded ``size`` x ``size`` image into a 1 x size**2 vector.

    The first ``size`` characters of each of the first ``size`` lines are
    parsed with int(). They are stored row-major, so character j of line i
    lands at index size*i + j.

    Args:
        filename: path to the text image.
        size: side length of the square image (defaults to 32).

    Returns:
        A (1, size*size) float array.
    """
    returnVect = np.zeros((1, size * size))
    # 'with' closes the file; the original left the handle open.
    with open(filename) as fr:
        for i in range(size):
            lineStr = fr.readline()
            for j in range(size):
                returnVect[0, size * i + j] = int(lineStr[j])
    return returnVect
# Test harness for the handwritten digit recognition system.
def _digitLabel(fileNameStr):
    """Return the class digit encoded before '_' in a file name, e.g. '7_45.txt' -> 7."""
    fileName = fileNameStr.split('.')[0]
    return int(fileName.split('_')[0])


def handwritingClassTest(trainingDir=r'C:\Private\PycharmProjects\Algorithm\kNN\digits\traingDigits',
                         testDir=r'C:\Private\PycharmProjects\Algorithm\kNN\digits\testDigits',
                         k=3):
    """Evaluate kNN digit recognition on a directory of text-encoded images.

    Every file in ``trainingDir`` is loaded with img2vector. Its label comes
    from the file name's prefix before '_'. Every file in ``testDir`` is then
    classified against the training set. Each prediction, the error count
    and the error rate are printed.

    Args:
        trainingDir: directory of training images.
        testDir: directory of test images.
        k: number of neighbours passed to classify0.

    Returns:
        The error rate as a fraction in [0, 1].

    Raises:
        ValueError: if testDir contains no files.
    """
    trainingFileList = listdir(trainingDir)
    m = len(trainingFileList)
    # 1024 = 32*32, the vector length produced by img2vector's default size.
    trainingMat = np.zeros((m, 1024))
    hwLabels = []
    for i, fileNameStr in enumerate(trainingFileList):
        hwLabels.append(_digitLabel(fileNameStr))
        trainingMat[i, :] = img2vector(os.path.join(trainingDir, fileNameStr))

    testFileList = listdir(testDir)
    mTest = len(testFileList)
    if mTest == 0:
        raise ValueError('no test files found in %s' % testDir)
    errorCount = 0.0
    for fileNameStr in testFileList:
        classNumber = _digitLabel(fileNameStr)
        vectorUnderTest = img2vector(os.path.join(testDir, fileNameStr))
        # Euclidean-distance kNN vote against the whole training set.
        classifyResult = classify0(vectorUnderTest, trainingMat, hwLabels, k)
        print('The classifier came back with: %d, the real answer is: %d' % (classifyResult, classNumber))
        if classifyResult != classNumber:
            errorCount += 1.0
    errorRate = errorCount / float(mTest)
    print('The total number of errors is: %d' % (errorCount))
    print('The total error rate is: %.1f%%' % (errorRate * 100))
    return errorRate
124 #Simple unit test of func: file2matrix()
125 #datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
126 #print (datingDataMat)
127 #print (datingLabels)
128
129 #Usage of figure construction of matplotlib
130 #fig=plt.figure()
131 #ax = fig.add_subplot(111)
132 #ax.scatter(datingDataMat[:,1], datingDataMat[:,2], 15.0*np.array(datingLabels), 15.0*np.array(datingLabels))
133 #plt.show()
134
135 #Simple unit test of func: autoNorm()
136 #normMat, ranges, minVals = autoNorm(datingDataMat)
137 #print (normMat)
138 #print (ranges)
139 #print (minVals)
140
141 #Simple unit test of func: img2vector
142 #testVect = img2vector('C:\\Private\\PycharmProjects\\Algorithm\\kNN\digits\\testDigits\\0_13.txt')
143 #print (testVect[0, 32:63] )
144
# Run the demos only when executed as a script, so importing this module does
# not trigger file I/O or block on input().
if __name__ == '__main__':
    # Dating-site classifier: hold-out evaluation.
    datingClassTest()

    # Dating-site classifier: interactive prediction.
    classifyPerson()

    # Handwritten digit recognition: evaluation on the test set.
    handwritingClassTest()