一、算法简要
“物以类聚人以群分”是生活的现实写照,knn就是让那些距离近的人或物归为一类。
问题定义:基于给定的一些示例(事物的属性features和该事物的分类class),对于某个特定或一系列事物的features,来对未定事物进行分类classifying。
一般把给出了事物features和class的集合叫做测试集TrainingSet,未给定即待定的事物集合称为测试集TestSet。
knn基本思想:输入没有标签或分类的新数据后,将新数据的特征features与测试集中的每一个数据特征进行比较,然后提取样本中k个特征最相似数据(最邻近的)的分类标签,选择k个最相似数据中出现次数最多的分类,作为新数据或新事物的分类。
二、算法一般流程伪代码
1.数据的准备
2.数据预处理:格式、是否要归一化(各个features之间的scale差距过大)等
3.分析数据:为了便于算法的实践,可以对原始数据集或预处理后的数据集进行一些实验性的统计和图示
4.训练算法:knn没有参数需要训练,但需要设置k值以及相似度计算方法
5.测试算法:基于TestSet进行计算
6.真实算法部署:可能需要进行语言抓换或平台部署
三、knn伪代码
1.计算已知类别数据集中的点与需要预测点之间的距离;
2.按照距离进行递增排序;
3.选择最近的k个点;
4.统计k个点中class最多的class
5.返回预测结果
四、实现
说明:knn主函数中,可以设置参数filename、testRatio和k值。这里把训练集和测试集都装在filename中,testRatio指定了testSet所占比重(前testRatio为测试数据),k则为选择最邻近邻居的个数。函数file2matrix将数据存储到matrix中,plot进行了散点图绘制,Norm系列实现了两种不同归一化方法,classify系列实现了不同的相似度计算方法。数据集下载点击这里
from numpy import *
import matplotlib.pyplot as plt
import matplotlib
import operator
def file2matrix(filename):
fr = open(filename)
arrayOLines = fr.readlines()
fr.close()
numberOfLines = len(arrayOLines)
returnMat = zeros((numberOfLines,3))
'''zeros这个函数是numpy命名空间中的,这里头文件中from import将numpy下所有变量命名都引入了,还有一种方式是import numpy但这个时候用zeros时需要加上numpy.zeros()。用from import方式是将所有相关变量名都引入了,所以对重构产生一定影响,不过一般还好'''
classLabelVector = []
index =0
for line in arrayOLines:
line = line.strip("\n")
listFromLine = line.split('\t')
returnMat[index,:] = listFromLine[0:3]
classLabelVector.append(int(listFromLine[-1]))
index=index+1
return returnMat, classLabelVector
def plot(dataMatrix, label):
fig=plt.figure()
ax=fig.add_subplot(111)
ax.scatter(dataMatrix[:,1],dataMatrix[:,2],0.01*array(dataMatrix),0.01*array(label))
plt.show()
'''(val-min)/(max-min)'''
def Norm0(dataSet):
minVals = dataSet.min(0)
maxVals = dataSet.max(0)
ranges= maxVals - minVals
normDataSet = zeros(shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet - tile(minVals,(m,1))
normDataSet = normDataSet/tile(ranges,(m,1))
return normDataSet
'''(val-mean)/std'''
def Norm1(dataSet):
meanVals = dataSet.mean(0)
stdVals = dataSet.std(0)
normDataSet = zeros(shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet - tile(meanVals,(m,1))
normDataSet = normDataSet/tile(stdVals,(m,1))
return normDataSet
'''Euclidean Distance'''
def classify0(inX, trainSet, labels, k):
dataSetSize = trainSet.shape[0]
diffMat = tile(inX, (dataSetSize,1))-trainSet
sqDiffMat = diffMat**2
sqDistance = sqDiffMat.sum(axis=1)
distances = sqDistance**0.5
sortedDistIndicies = distances.argsort()
classCount={}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel,0)+1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
'''Manhattan Distance'''
def classify1(inX, trainSet, labels, k):
dataSetSize = trainSet.shape[0]
diffMat = tile(inX, (dataSetSize,1))-trainSet
DiffMat = abs(diffMat)
distances = DiffMat.sum(axis=1)
sortedDistIndicies = distances.argsort()
classCount={}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel,0)+1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
'''Chebyshev Distance'''
def classify2(inX, trainSet, labels, k):
dataSetSize = trainSet.shape[0]
diffMat = tile(inX, (dataSetSize,1))-trainSet
distances = diffMat.max(1)
sortedDistIndicies = distances.argsort()
classCount={}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel,0)+1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
def knn(filename, testRatio, k):
dataSet, labels = file2matrix(filename)
normMat = Norm0(dataSet)
m = normMat.shape[0]
numTestVecs = int(m*testRatio)
errorCount = 1.0
for i in range(numTestVecs):
classifierResult = classify2(normMat[i,:],normMat[numTestVecs:m,:],labels[numTestVecs:m],k)
if(classifierResult != labels[i]):
errorCount+=1.0
print "Total error rate of test set is: %f"%(errorCount/float(numTestVecs))
filename="datingTestSet2.txt"
knn(filename,0.1,2)
knn(filename,0.1,3)
knn(filename,0.1,4)
knn(filename,0.1,5)
knn(filename,0.1,6)