算法概述
算法定义:k近邻采用计算预测点与样本数据特征的距离,选取k个距离最近的样本标签(这个标签其实就是数据的分类,这个数据属于哪个类别,比如某一条电影数据前面有好几个特征值,后边这个标签就是标记这条电影数据属于那种类型的电影,例如属于爱情类电影,这个标签就是爱情),找到k个距离最近的样本标签后,统计这k个样本重出现次数最多的那个分类。
算法条件:需要有样本数据,由于需要计算距离,样本数据特征之间的数值大小相差不应该太大,比如某个特征数值为10000,另外一个特征数值为0.11,这样计算出来的距离和第一个特征正相关,其他特征被屏蔽,这样的样本数据就需要对样本数据进行归一化处理,可以将数据转化为0-1之间的数值
可以使用这个公式: (x-min)/(max-min)
- 式中的x表示当前被归一化的数据
- 式子中的min是这该特征数据的最小值
- max是这个特征数据最大值
通过上面的式子,我们就能够把这个特征数据进行归一化处理
算法实现过程:
- 准备样本数据,样本数据最好能够有一定的格式,这样在进行程序处理的时候就能快速的处理了
- 样本数据归一化(如果需要的话)
- 输入待分类数据
- 计算带分类数据与各条样本数据的距离
- 找到前k个距离最近的样本数据
- 找到这个k个样本数据中,出现次数最多的标签
找到计算次数最多的标签,即为这个待分类数据的预测分类
算法实例
下面介绍一个书中的实例,利用样本数据进行手写数字识别
这个例子中的样本数据是二值化的数据
通过0-9的10个手写样本数据进行加载,计算与待分类样本数据距离,通过距离最近的k个样本,估计待测样本数据的预测值。
在计算中大量使用了numpy这个python的科学计算包,进行矩阵的运算
from testpackage import test
from numpy import *
import operator
import matplotlib
import matplotlib.pyplot as plt
from os import listdir
def classify0(inx,dataset,lables,k):
datasetSize = dataset.shape[0]
diffdataset = tile(inx,(datasetSize,1))-dataset
sqdiffdataset = diffdataset**2
sqdiffdistances = sqdiffdataset.sum(axis=1)
distances = sqdiffdistances**0.5
sorteddistance = distances.argsort()
classcount = {}
for i in range(k):
votelable = lables[sorteddistance[i]]
classcount[votelable] = classcount.get(votelable,0) +1
sortedClassCount = sorted(classcount.items(),key=operator.itemgetter(0),reverse = True)
return sortedClassCount[0][0]
def datingfileToMatrix(filename):
fr = open(filename)
datinglines = fr.readlines()
datinglength = len(datinglines)
datingMat = zeros((datinglength,3))
classLablesV = []
index = 0
for line in datinglines:
row = line.strip()
listfromLine = row.split('\t')
datingMat[index,:] = listfromLine[0:3]
classLablesV.append(int(listfromLine[-1]))
index+=1
return datingMat,classLablesV
def autoNorm(dataSet):
minvaules = dataSet.min(0)
maxvalues = dataSet.max(0)
rangevalue = maxvalues - minvaules
normDataSet = zeros(shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet - tile(minvaules,(m,1))
normDataSet = normDataSet/tile(rangevalue,(m,1))
return normDataSet,rangevalue,minvaules
def datingClassTest():
hoRatio = 0.50 #hold out 10%
datingDataMat,datingLabels = datingfileToMatrix('datingTestSet2.txt') #load data setfrom file
normMat, ranges, minVals = autoNorm(datingDataMat)
m = normMat.shape[0]
numTestVecs = int(m*hoRatio)
errorCount = 0.0
for i in range(numTestVecs):
classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
if (classifierResult != datingLabels[i]): errorCount += 1.0
print("the total error rate is: %f" % (errorCount/float(numTestVecs)))
print(errorCount)
def img2vector(filename):
file = open(filename)
returnvector = zeros((1,1024))
for i in range(32):
lineStr = file.readline()
for j in range(32):
returnvector[0,i*32+j] = int(lineStr[j])
return returnvector
def handwritingclassTest():
hwlables = []
trainingfiles = listdir('digits/trainingDigits')
m = len(trainingfiles)
trainmat = zeros((m,1024))
for i in range(m):
filename = trainingfiles[i]
fileStr = filename.split('.')[0]
rightnumber = int(fileStr.split('_')[0])
hwlables.append(rightnumber)
trainmat[i,:] = img2vector('digits/trainingDigits/{0}'.format(filename))
testfilelist = listdir('digits/testDigits')
errorcount = 0.0
mTest = len(testfilelist)
for i in range(mTest):
filename = testfilelist[i]
fileStr = filename.split('.')[0]
rightNumber = int(fileStr.split('_')[0])
testVector = img2vector('digits/testDigits/{0}'.format(filename))
classr = classify0(testVector,trainmat,hwlables,3)
print('the classifier came back is {0},the real answer is : {1}'.format(classr,rightNumber))
if classr != rightNumber:
errorcount+=1
print('the total error number is {0}'.format(errorcount))
print('the total error rate is {0}'.format(errorcount/float(mTest)))
if __name__ == "__main__":
group,lables = createDataset()
r = classify0([0,0],group,lables,3)
print(r)
datingmat,classlab = datingfileToMatrix("datingTestSet2.txt")
normMat,matranges,matminvals = autoNorm(datingmat)
# print(normMat)
# print(matranges)
# print(matminvals)
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(datingmat[:,0],datingmat[:,1],15.0*array(classlab),15.0*array(classlab))
plt.show()
# datingClassTest()
# handwritingclassTest()
总结
k近邻算法的主要优点:
- 对异常数据不敏感
缺点就是计算复杂度太高了,空间复杂度相对较高,他需要将带预测的数据与样本数据逐个计算距离,如果样本数据达到几百万条,这样的计算量是非常大的