一,k近邻算法概述
k近邻算法是一种简单有效但并不高效的非线性分类方法。
- 优点:精度高,对异常值不敏感、无数据输入假设。
- 缺点:计算复杂度高、空间复杂度高。
- 使用数据范围:离散型和连续型。
二,k近邻算法的核心步骤
对未知类别属性的数据集中的每一个点依次执行以下操作:
1. 计算已知数据集中的点与当前点之间的距离。
2. 按照距离递增次序排序。
3. 选取与当前点距离最小的k个点。
4. 确定前k个点所在类别的出现频率。
5. 返回前k个点出现频率最高的类别作为当前点的预测分类。
三,k近邻算法应用的一般流程
- 收集数据:可以使用任何方法。例如:存储到数据库(mysql、mongodb等)或者直接存储成文本文件。
- 准备数据:距离计算所需要的数值,最好是结构化的数据格式。
- 分析数据:可以使用任何方法。例如:用matplotlib画二维扩散图。
- 训练算法:此步骤不适用于k近邻算法,因为k近邻直接基于实例,无需训练。
- 测试算法:在测试集上计算错误率。
- 使用算法:首先输入样本数据和结构化的输出结果,然后运行k近邻算法判定输入数据属于哪一类别,最后应用对计算出的分类执行后续的处理。
四,k近邻算法应用的Python3代码实现
from numpy import *
import operator
import matplotlib
import matplotlib.pyplot as plt
# 1.收集数据省略,因为已给出数据。
# 4.训练算法省略,因为k近邻算法无需训练。
# k近邻算法核心步骤
def classify0(inX,dataSet,labels,k):
dataSetSize = dataSet.shape[0]
diffMat = tile(inX,(dataSetSize,1))-dataSet
sqDiffMat = diffMat**2
sqDistances = sum(sqDiffMat,axis=1)
sortedDistIndicies = sqDistances.argsort()
classCount = {}
for i in range(k):
voteILabel = labels[sortedDistIndicies[i]]
classCount[voteILabel] = classCount.get(voteILabel,0)+1
sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0]
# 2.准备数据
def file2matrix(filename):
fr = open(filename)
fileLines = fr.readlines()
numsOfLines = len(fileLines)
returnMat = zeros((numsOfLines,3))
index = 0
classLabelVector = []
for line in fileLines:
listOfLine = line.strip().split('\t')
returnMat[index,:] = listOfLine[:3]
classLabelVector.append(int(listOfLine[-1]))
index += 1
return returnMat,classLabelVector
# 3.分析数据
def plotData(DataSet,Labels,i,j):
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(DataSet[:,i],DataSet[:,j],15.0*array(Labels),15.0*array(Labels))
plt.show()
# 对数据进行数值归一化
def autoNorm(dataSet):
minVals = dataSet.min(0)
maxVals = dataSet.max(0)
normDataSet = zeros(dataSet.shape)
n = dataSet.shape[0]
normDataSet = dataSet - tile(minVals,(n,1))
ranges = maxVals - minVals
normDataSet = normDataSet/tile(ranges,(n,1))
return normDataSet,ranges,minVals
# 5.测试算法
def datingClassTest():
hoRatio = 0.1
datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')
normDataSet,ranges,minVals = autoNorm(datingDataMat)
n = normDataSet.shape[0]
numTestSet = int(n*hoRatio)
errorCount = 0
for i in range(numTestSet):
classifyResult = classify0(normDataSet[i,:],normDataSet[numTestSet:n,:],datingLabels[numTestSet:n],3)
print('The classifier result is: %d,and the real answer is:%d。'%(classifyResult,datingLabels[i]))
if classifyResult!=datingLabels[i]:
errorCount += 1
print('The total error rate is: %f'%(errorCount/numTestSet))
# 6.使用算法
def classifyPerson():
resultList = ['not at all','is small doses','in large doses']
Game = float(input('please input the percentage of time spent in playing video games:'))
FlyMiles = float(input('please input the fly miles:'))
iceCream = float(input('please input the liters of ice cream consumed per year:'))
inArr = array([FlyMiles,Game,iceCream])
dateSet,dateLabels = file2matrix('datingTestSet2.txt')
normData,ranges,minVals = autoNorm(dateSet)
result = classify0((inArr-minVals)/ranges,normData,dateLabels,3)
print('You will probably like this person:',resultList[result-1])
附录-Python程序中用到的函数
1, numpy.tile(A,B)函数:,重复A,B次,这里的B可以时int类型也可以是元组类型。
>>> import numpy
>>> numpy.tile([0,0],5)#在列方向上重复[0,0]5次,默认行1次
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
>>> numpy.tile([0,0],(1,1))#在列方向上重复[0,0]1次,行1次
array([[0, 0]])
>>> numpy.tile([0,0],(2,1))#在列方向上重复[0,0]1次,行2次
array([[0, 0],
[0, 0]])
>>> numpy.tile([0,0],(3,1))
array([[0, 0],
[0, 0],
[0, 0]])
>>> numpy.tile([0,0],(1,3))#在列方向上重复[0,0]3次,行1次
array([[0, 0, 0, 0, 0, 0]])
>>> numpy.tile([0,0],(2,3))<span style="font-family: Arial, Helvetica, sans-serif;">#在列方向上重复[0,0]3次,行2次</span>
array([[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]])
2, operator模块中的常用函数。例如使用 itemgetter() 从元组序列中获取指定的域值。
>>> inventory = [('apple', 3), ('banana', 2), ('pear', 5), ('orange', 1)]
>>> getcount = operator.itemgetter(1)
>>> map(getcount, inventory)
[3, 2, 5, 1]
>>> sorted(inventory, key=getcount)
[('orange', 1), ('banana', 2), ('apple', 3), ('pear', 5)]
3, python2中的iteritems()在python3中变为了items():
在Python2.x中,items( )用于 返回一个字典的拷贝列表【Returns a copy of the list of all items (key/value pairs) in D】,占额外的内存。
iteritems() 用于返回本身字典列表操作后的迭代【Returns an iterator on all items(key/value pairs) in D】,不占用额外的内存。
Python 3.x 里面,iteritems() 和 viewitems() 这两个方法都已经废除了,而 items() 得到的结果是和 2.x 里面 viewitems() 一致的。在3.x 里 用 items()替换iteritems() ,可以用于 for 来循环遍历。
4, matplotlib模块中的subplot()方法。
subplot(numRows, numCols, plotNum)
subplot将整个绘图区域等分为numRows行* numCols列个子区域,然后按照从左到右,从上到下的顺序对每个子区域进行编号,左上的子区域的编号为1。如果numRows,numCols和plotNum这三个数都小于10的话,可以把它们缩写为一个整数,例如subplot(323)和subplot(3,2,3)是相同的。subplot在plotNum指定的区域中创建一个轴对象。如果新创建的轴和之前创建的轴重叠的话,之前的轴将被删除。
import matplotlib
import matplotlib.pyplot as plt
for i,color in enumerate("rgbyck"):
plt.subplot(321+i,axisbg=color)
plt.show()
效果如下: