k-近邻算法
简单的说,k-近邻算法采用测量不同特征值之间的距离方法进行分类;
优点:精度高、对异常值不敏感、无数据输入假定;
缺点:计算复杂度高、空间复杂度高;
适用范围:数值型和标称型;
工作原理:
存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都存在标签。
即我们知道样本集合中每一数据与所属分类的对应关系。
输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较。
然后算法提取样本集中特征最相似数据(最近邻)的分类标签。
一般来说,我们只选择样本数据中前k个最相似数据,这就是k-近邻算法中k的出处。
通常k是不大于20的整数。
最后选择k个最相似数据中出现次数最多的分类,作为新数据的分类;
k-近邻算法的一般流程
1、收集数据
可以使用任何方法;
2、准备数据
距离计算需要的数值,最好是结构化的数据格式;
3、分析数据
可以使用任何方法;
4、训练算法
此步骤不适合用于k-近邻算法;
5、测试算法
计算错误率;
6、使用算法
首先需要输入样本数据和结构化的输出结果;
然后运行k-近邻算法判定输入数据分别属于哪个分类;
最后应用于计算出的分类执行后续的处理;
准备工作
Python导入数据
先创建通用函数类:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
k-近邻算法 造个数据
@author: ge
"""
from numpy import *
import operator
def createDataSet():
group = array([[1.0, 1.1], [1.0, 1.1], [0, 0], [0, 0.1]])
labels = ['A', 'A', 'B', 'B']
retu
- 操作类:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
k-近邻算法 实现类
@author: ge
"""
import kNN
if __name__ == "__main__":
group, labels = kNN.createDataSet()
print(group)
print(labels)
实施kNN分类算法
- 伪代码
对未知类别属性的数据集中的每个点一次执行以下操作:
1、计算已知类别数据集中的点与当前点之间的距离;
2、按照距离递增次序排序;
3、选取与当前点距离最小的k个点;
4、确定前k个点所在类别的出现频率;
5、返回前k个点出现频率最高的类别作为当前点的预测分类;
看代码:
- 在刚才产生数据的kNN的类中加入方法:
# k-近邻算法
def classify0(inX, dataSet, labels, k):
# 距离计算 欧式距离公式
dataSetSize = dataSet.shape[0]
diffMat = tile(inX, (dataSetSize, 1)) - dataSet
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances**0.5
# 选择距离最小的k个点
sortedDistIndicies = distances.argsort()
classCount = {}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
# 排序
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
- 调用:
print(kNNData.classify0([0, 0], group, labels, 3))
- 打印结果
B
即输入了[0,0]这个点,输出B,代表[0,0]点离B最近;
也可以试着输入其他值测试结果;
约会网站匹配优化
- 上面栗子实际生活中并没有什么卵用
下面看一个实战栗子,是约会网站的数据,通过三个特征值进行判断是否是喜欢的类型:
- 在datingTestSet2.txt中准备了一千条数据
- 三个特征值:
每年获得的飞行常客里程数
玩视频游戏所耗时间百分比
每周消费的冰淇淋公升数
- 通过上面的三个特征的综合考虑可以得出三种结果:
不喜欢的人
魅力一般的人
极具魅力的人
展示几个样本示例:
40920 8.326976 0.953952 3
14488 7.153469 1.673904 2
26052 1.441871 0.805124 1
75136 13.147394 0.428964 1编写分类器方法代码:
# 从文件中解析数据
def file2matrix(fileName):
fr = open(fileName)
arrayOfLines = fr.readlines()
numberOfLines = len(arrayOfLines) # 得到文件行数
returnMat = zeros((numberOfLines, 3)) # 创建返回的Numpy矩阵
classLabelVector = []
index = 0
# 解析文件数据得到列表
for line in arrayOfLines:
line = line.strip()
listFromLine = line.split('\t')
returnMat[index, :] = listFromLine[0:3]
classLabelVector.append(int(listFromLine[-1]))
index += 1
return returnMat, classLabelVector
- 调用
datingDataMat, datingLabels = kNN.file2matrix('datingTestSet2.txt')
print(datingDataMat)
print(datingLabels)
# 画图 第二列第三列数据展示 "玩视频游戏所耗时间百分比"和"每周所消费的冰淇淋公升数"
import matplotlib
import matplotlib.pyplot as plt
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2])
plt.show()
打印结果
[[ 4.09200000e+04 8.32697600e+00 9.53952000e-01]
[ 1.44880000e+04 7.15346900e+00 1.67390400e+00]
[ 2.60520000e+04 1.44187100e+00 8.05124000e-01]
...,
[ 2.65750000e+04 1.06501020e+01 8.66627000e-01]
[ 4.81110000e+04 9.13452800e+00 7.28045000e-01]
[ 4.37570000e+04 7.88260100e+00 1.33244600e+00]]
[3, 2, 1, 1, 1, 1, ... 3, 3]
和一张散点图,但是散点图中并不能看出什么。
下面利用Matplotlib库提供的scatter函数支持个性化标记散点图上的点。
ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2], 15.0*array(datingLabels), 15.0*array(datingLabels))
再次运行,得到分了颜色的散点图,但是也没什么卵用。
以上是学习Matplotlib库图形化展示数据;
下面看看真正的约会网站好感评估:
- 数据准备
# 归一化特征值
def autoNorm(dataSet):
minVals = dataSet.min(0)
maxVals = dataSet.max(0)
ranges = maxVals - minVals
normDataSet = zeros(shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet - tile(minVals, (m, 1))
normDataSet = normDataSet / tile(ranges, (m, 1))
return normDataSet, ranges, minVals
调用:
normMat, ranges, minVals = kNN.autoNorm(datingDataMat)
print(normMat)
print(ranges)
print(minVals)
打印结果:
[[ 0.44832535 0.39805139 0.56233353]
[ 0.15873259 0.34195467 0.98724416]
[ 0.28542943 0.06892523 0.47449629]
…,
[ 0.29115949 0.50910294 0.51079493]
[ 0.52711097 0.43665451 0.4290048 ]
[ 0.47940793 0.3768091 0.78571804]]
[ 9.12730000e+04 2.09193490e+01 1.69436100e+00]
[ 0. 0. 0.001156]
- 作为完整程序验证分类器
将测试数据datingTestSet.txt放入
代码:
# 分类器针对约会网站的测试代码
def datingClassTest():
hoRatio = 0.10
datingDataMat, datingLabels = file2matrix('datingTestSet.txt')
normat, ranges, minVals = autoNorm(datingDataMat)
m = normat.shape[0]
numTestVecs = int(m.hoRatio)
errorCount = 0.0
for i in range(numTestVecs):
classifierResult = classify0(normat[i, :], normat[numTestVecs:m, :], datingLabels[numTestVecs:m, 3])
print("the classifier came back with: %d, the real answer is: %d %(classifierResult, datingLabels[i])")
if(classifierResult != datingLabels[i]): errorCount += 1.0
print("the total error rate is: %f" % (errorCount / float(numTestVecs)))
测试调用:
kNN.datingClassTest()
即可输出测试正确率了
总结
1、k-近邻算法是分类数据最简单最有效的算法;
2、k-近邻算法是基于实例的学习,使用算法时我们必须有接近实际数据的训练样本数据;
3、k-近邻算法必须保存全部数据集,如果训练数据集很大,必须使用大量的存储空间;
4、由于必须对数据集中的每个数据计算距离值,实际使用时可能非常耗时;
5、另一个缺陷是他无法给出任何数据的基础结构信息,因此也无法知晓平均实例样本和典型样本具体有什么特征;