K-近邻算法
1. 工作原理:
首先它是用来分类,也比较简单。举个西瓜的例子,现在训练集中有10个西瓜,我们观察西瓜的三个属性,色泽、根蒂、敲声,分别得到对应的属性值。并且,我们已经知道这10个瓜是好瓜还是坏瓜!那么,现在有个新的西瓜样本,我们不知道它是不是好瓜,我们只知道它的三个属性,所以,这个时候我们对它进行判断,即分类。
2.算法过程:
以上面的例子简单讲一下,顾名思义,找K个最接近的样本,然后看它们大都数属于哪一类。步骤如下:
(1)计算新西瓜与训练集中10个西瓜的距离
(2)根据计算结果,找与新西瓜最近的K个
(3)观察k个西瓜中,有多少是好西瓜有多少是坏西瓜,选择比例大的。如:k=5,你发现4个都为好瓜, 那么这个新西瓜就为好瓜,反之亦然
3.Python代码
这里为了简单,没用西瓜的例子,就简单数值,我随机写了数据集,ok,闲话少说,直接上代码!
# coding = utf-8
from numpy import *
import operator
def createDataSet():
group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
labels = ['A', 'A', 'B', 'B']
return group, labels
def classify(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
# 距离计算(为了方便选择欧氏距离)
diffMat = tile(inX, (dataSetSize, 1)) - dataSet # 测试集与训练集做差
sqDiffMat = diffMat ** 2 # 对矩阵中的每个元素做平方
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances ** 0.5
sortedDistIndicies = distances.argsort() # 数组安值大小返回索引
# print sortedDistIndicies
classCount = {}
# 选择最小的k的点
for i in range(k):
voteLabel = labels[sortedDistIndicies[i]] # 找到第一个最近的样本的label
classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
# 排序
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
# 测试一个例子
group, label = createDataSet()
print classify([0, 0], group, label, 3)
运行的结果:
即,把[0,0]输入进去,得到的是B类
总结:
这里只是用简单的例子做简单的介绍,读者也可以用其余的数据集去测试,因为我这里的数据是两维的,在现实中,可能是多维的。