"""k-nearest-neighbour classifier for the 'dating' example dataset.

Run as a script to load ``datingTestSet2.txt``, plot it and evaluate the
classifier on a hold-out split. Importing the module only defines the
functions; it performs no file I/O and no plotting.
"""
import operator

import numpy as np


def readfile(filename, numFeatures=3):
    """Parse a tab-separated data file into a feature matrix and label list.

    Each non-blank line holds at least ``numFeatures`` numeric columns
    followed by an integer class label in the LAST column. Only the first
    ``numFeatures`` columns are used as features.

    Args:
        filename: path of the tab-separated text file.
        numFeatures: number of leading columns to use as features
            (defaults to 3, matching the dating dataset).

    Returns:
        (returnMat, classLabelVector): a float array of shape
        (number_of_samples, numFeatures) and a list of int labels.
    """
    rows = []
    classLabelVector = []
    # `with` guarantees the file is closed even if parsing fails.
    with open(filename) as fr:
        for line in fr:
            line = line.strip()
            if not line:
                # Blank lines (e.g. an extra newline at EOF) carry no sample.
                continue
            listFromLine = line.split('\t')
            rows.append([float(v) for v in listFromLine[:numFeatures]])
            classLabelVector.append(int(listFromLine[-1]))
    # reshape keeps the (0, numFeatures) shape for an empty file.
    returnMat = np.array(rows, dtype=float).reshape(len(rows), numFeatures)
    return returnMat, classLabelVector


def autoNorm(dataSet):
    """Min-max scale every column of ``dataSet`` into the range [0, 1].

    Args:
        dataSet: 2-D array-like of samples (rows) by features (columns).

    Returns:
        (normDataSet, ranges, minVals): the scaled data, the per-column
        ``max - min`` and the per-column minimum. A column whose range is 0
        (all values equal) is scaled to all zeros instead of NaN; its entry
        in ``ranges`` is still reported as 0.
    """
    dataSet = np.asarray(dataSet, dtype=float)
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    # Avoid 0/0 -> NaN for constant columns; their numerator is 0 anyway.
    safeRanges = np.where(ranges == 0, 1.0, ranges)
    # Broadcasting subtracts/divides the per-column vectors from every row.
    normDataSet = (dataSet - minVals) / safeRanges
    return normDataSet, ranges, minVals


def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its k nearest neighbours.

    Distances are Euclidean. Neighbours at equal distance keep their
    dataset order, and a tied vote goes to the label that was encountered
    first among the nearest neighbours.

    Args:
        inX: feature vector to classify (same length as a row of dataSet).
        dataSet: 2-D array-like of training samples, one per row.
        labels: sequence of labels, ``labels[i]`` belonging to row i.
        k: number of nearest neighbours that vote; values larger than the
            number of samples are clamped to it.

    Returns:
        The winning label.

    Raises:
        ValueError: if ``k < 1`` or ``dataSet`` has no rows.
    """
    dataSet = np.asarray(dataSet, dtype=float)
    dataSetSize = dataSet.shape[0]
    if dataSetSize == 0:
        raise ValueError('dataSet must contain at least one sample')
    if k < 1:
        raise ValueError('k must be a positive integer')
    k = min(k, dataSetSize)

    # Broadcasting replaces tile(): inX is subtracted from every row.
    diffMat = np.asarray(inX, dtype=float) - dataSet
    distances = np.sqrt((diffMat ** 2).sum(axis=1))
    # A stable sort makes tie-breaking between equal distances deterministic.
    sortedDistIndicies = distances.argsort(kind='stable')

    classCount = {}
    for i in sortedDistIndicies[:k]:
        voteIlabel = labels[i]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # max() returns the first maximal item, i.e. the earliest-seen label on a tie.
    return max(classCount.items(), key=operator.itemgetter(1))[0]


def datingClassTest(hoRatio=0.10, filename='datingTestSet2.txt', k=3):
    """Evaluate classify0 on a hold-out split of the dating dataset.

    The first ``int(m * hoRatio)`` rows are classified against the remaining
    rows. Each prediction and the overall error rate are printed.

    Args:
        hoRatio: fraction of rows held out for testing.
        filename: data file passed to readfile.
        k: number of neighbours used by classify0.

    Returns:
        The error rate as a float in [0, 1].

    Raises:
        ValueError: if the hold-out split contains no rows.
    """
    datingDataMat, datingLabels = readfile(filename)
    normMat, _, _ = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m * hoRatio)
    if numTestVecs < 1:
        raise ValueError('hoRatio=%r selects no test rows out of %d' % (hoRatio, m))

    errorCount = 0
    for i in range(numTestVecs):
        classifierResult = classify0(
            normMat[i, :], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], k
        )
        print('分类器反馈:%d,实际是:%d' % (classifierResult, datingLabels[i]))
        if classifierResult != datingLabels[i]:
            errorCount += 1
    errorRate = errorCount / float(numTestVecs)
    print('总的差错率是:%f' % errorRate)
    return errorRate


def _plotDating(normMat, labels):
    """Draw four scatter plots of the normalized data, sized/coloured by label."""
    # Imported lazily so the classifier functions do not require matplotlib.
    import matplotlib.pyplot as plt

    markerScale = 15.0 * np.array(labels)
    fig = plt.figure()
    ax = fig.add_subplot(221)
    ax.scatter(normMat[:, 1], normMat[:, 2])
    ax = fig.add_subplot(222)
    ax.scatter(normMat[:, 1], normMat[:, 2], markerScale, markerScale)
    ax = fig.add_subplot(223)
    ax.scatter(normMat[:, 0], normMat[:, 1])
    ax = fig.add_subplot(224)
    ax.scatter(normMat[:, 0], normMat[:, 1], markerScale, markerScale)
    plt.show()


if __name__ == '__main__':
    A, B = readfile('datingTestSet2.txt')
    print(A, B)
    # A: normalized matrix, C: per-column ranges, D: per-column minimums
    A, C, D = autoNorm(A)
    _plotDating(A, B)
    print('-----------------------------')
    datingClassTest()
以上是具体代码。书中给出的代码基于 Python 2,本人的代码基于 Python 3。
下面写点总结:
kNN 代码的核心部分是 classify0 函数,其伪代码如下:
对未知类别属性的数据集中的每个点依次执行以下操作:
(1) 计算已知数据集中的点与当前点之间的距离
(2) 按照距离递增次序排序
(3) 选取与当前点距离最小的k个点
(4) 统计前 k 个点中各类别出现的频率
(5) 返回前 k 个点中出现频率最高的类别,作为当前点的预测分类
代码中 inX 是用于分类的输入向量,dataSet 是训练样本集,labels 是训练样本集对应的标签向量,k 是选取的最近邻居的数目(而不是类别数)。