一、k-近邻算法概述
简答的说k近邻算法采用测量不同特征值之间的距离方法进行分类。
工作原理是: 存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系。输人没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。一般来说,我们只选择样本数据集中前k个最相似的数据,这就是k-近邻算法中k的出处,通常k是不大于20的整数。最后,选择k个最相似数据中出现次数最多的分类,作为新数据的分类。
k-近邻算法的一般流程
(1)收集数据:可以使用任何方法。
(2)准备数据:距离计算所需要的数值,最好是结构化的数据格式。
(3)分析数据:可以使用任何方法。
(4)训练算法:此步驟不适用于k 近邻算法。
(5)测试算法:计算错误率。
(6)使用算法:首先需要输入样本数据和结构化的输出结果,然后运行女-近邻算法判定输入数据分别属于哪个分类,最后应用对计算出的分类执行后续的处理。
二、K近邻算法–约会网站配对
1、案例介绍
我的朋友海伦一直使用在线约会网站寻找适合自己的约会对象。尽管约会网站会推荐不同的人选,但她没有从中找到喜欢的人。经过一番总结,她发现曾交往过三种类型的人:
□ 不喜欢的人
□ 魅力一般的人
□ 极具魅力的人
海伦收集了自己的一些约会记录的数据信息。每个样本数据占据一行,总共有1000行。
数据文件 datingTestSet2.txt
海伦的样本数据局主要包含以下3种特征:
□ 每年获得的飞行常客里程数
□ 玩视频游戏所耗时间百分比
□ 每周消费的冰淇淋公升数
她希望通过以上这三个特征来进行判断自我感觉。
2、准备数据、分析数据
1、在将上述特征数据输人到分类器之前,必须将待处理数据的格式改变为分类器可以接受的格式。
2、接着我们需要了解数据的真实含义。当然我们可以直接浏览文本文件,但是这种方法非常不友好,一般来说,我们会采用图形化的方式直观地展示数据。
from numpy import *
from matplotlib.font_manager import FontProperties
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import operator #导入需要用到的包
def filematrix(filename):
fr=open(filename) #打开文件
arrayoflines=fr.readlines() #读取文件
numberoflines=len(arrayoflines) #得到文件的行数
returnmat=zeros((numberoflines,3)) #returnmat是一个numberoflines行3列的0矩阵,用来存放feature
classlabelvector=[] #设置一个空list来存label
index=0
for line in arrayoflines: #一行一行读数据
line=line.strip() #截取所有回车字符
listfromline=line.split('\t') #将上一步得到的整行数据分割成一个列表
returnmat[index,:]=listfromline[0:3] #注意矩阵的索引和列表的索引 #1-3列是特征列,依次将特征存放到returnmat中
#下面两行代码和下下面的代码。注意输出是数字还是字符串
a = float(listfromline[-1]) # 列表转换为数字
classlabelvector.append(a) #最后一列也就是第4列是label列,存放到classlabelvector中
#对于datingTestSet文档好使
# if listfromline[-1] == 'didntLike':
# classlabelvector.append(1)
# elif listfromline[-1] == 'smallDoses':
# classlabelvector.append(2)
# elif listfromline[-1] == 'largeDoses':
# classlabelvector.append(3)
index+=1
return returnmat,classlabelvector #返回特征矩阵和标签列表
def showdatas(datingDataMat, datingLabels):
#设置汉字格式
font = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=14)
#将fig画布分隔成1行1列,不共享x轴和y轴,fig画布的大小为(13,8)
#当nrow=2,nclos=2时,代表fig画布被分为四个区域,axs[0][0]表示第一行第一个区域
fig, axs = plt.subplots(nrows=2, ncols=2,sharex=False, sharey=False, figsize=(13,8))
numberOfLabels = len(datingLabels)
LabelsColors = []
for i in datingLabels:
if i == 1:
LabelsColors.append('black')
if i == 2:
LabelsColors.append('orange')
if i == 3:
LabelsColors.append('red')
#画出散点图,以datingDataMat矩阵的第一(飞行常客例程)、第二列(玩游戏)数据画散点数据,散点大小为15,透明度为0.5
axs[0][0].scatter(x=datingDataMat[:,0], y=datingDataMat[:,1], color=LabelsColors,s=15, alpha=.5)
#设置标题,x轴label,y轴label
axs0_title_text = axs[0][0].set_title(u'每年获得的飞行常客里程数与玩视频游戏所消耗时间占比',FontProperties=font)
axs0_xlabel_text = axs[0][0].set_xlabel(u'每年获得的飞行常客里程数',FontProperties=font)
axs0_ylabel_text = axs[0][0].set_ylabel(u'玩视频游戏所消耗时间占',FontProperties=font)
plt.setp(axs0_title_text, size=9, weight='bold', color='red')
plt.setp(axs0_xlabel_text, size=7, weight='bold', color='black')
plt.setp(axs0_ylabel_text, size=7, weight='bold', color='black')
#画出散点图,以datingDataMat矩阵的第一(飞行常客例程)、第三列(冰激凌)数据画散点数据,散点大小为15,透明度为0.5
axs[0][1].scatter(x=datingDataMat[:,0], y=datingDataMat[:,2], color=LabelsColors,s=15, alpha=.5)
#设置标题,x轴label,y轴label
axs1_title_text = axs[0][1].set_title(u'每年获得的飞行常客里程数与每周消费的冰激淋公升数',FontProperties=font)
axs1_xlabel_text = axs[0][1].set_xlabel(u'每年获得的飞行常客里程数',FontProperties=font)
axs1_ylabel_text = axs[0][1].set_ylabel(u'每周消费的冰激淋公升数',FontProperties=font)
plt.setp(axs1_title_text, size=9, weight='bold', color='red')
plt.setp(axs1_xlabel_text, size=7, weight='bold', color='black')
plt.setp(axs1_ylabel_text, size=7, weight='bold', color='black')
#画出散点图,以datingDataMat矩阵的第二(玩游戏)、第三列(冰激凌)数据画散点数据,散点大小为15,透明度为0.5
axs[1][0].scatter(x=datingDataMat[:,1], y=datingDataMat[:,2], color=LabelsColors,s=15, alpha=.5)
#设置标题,x轴label,y轴label
axs2_title_text = axs[1][0].set_title(u'玩视频游戏所消耗时间占比与每周消费的冰激淋公升数',FontProperties=font)
axs2_xlabel_text = axs[1][0].set_xlabel(u'玩视频游戏所消耗时间占比',FontProperties=font)
axs2_ylabel_text = axs[1][0].set_ylabel(u'每周消费的冰激淋公升数',FontProperties=font)
plt.setp(axs2_title_text, size=9, weight='bold', color='red')
plt.setp(axs2_xlabel_text, size=7, weight='bold', color='black')
plt.setp(axs2_ylabel_text, size=7, weight='bold', color='black')
#设置图例
didntLike = mlines.Line2D([], [], color='black', marker='.',
markersize=6, label='didntLike')
smallDoses = mlines.Line2D([], [], color='orange', marker='.',
markersize=6, label='smallDoses')
largeDoses = mlines.Line2D([], [], color='red', marker='.',
markersize=6, label='largeDoses')
#添加图例
axs[0][0].legend(handles=[didntLike,smallDoses,largeDoses])
axs[0][1].legend(handles=[didntLike,smallDoses,largeDoses])
axs[1][0].legend(handles=[didntLike,smallDoses,largeDoses])
"""
1.tight_layout命令:主要用于自动调整绘图区的大小及间距,
使所有的绘图区及其标题、坐标轴标签等都可以不重叠的完整显示在画布上。
2、tight_layout命令还有三个关键字参数:pad、w_pad、h_pad。
pad用于设置绘图区边缘与画布边缘的距离大小
w_pad用于设置绘图区间水平距离的大小
h_pad用于设置绘图区间垂直距离的大小
"""
fig.tight_layout(pad=2, w_pad=3.0, h_pad=3.0)
#显示图片
plt.show()
datingdatamat,datinglabel=filematrix('datingTestSet2.txt')
showdatas(datingdatamat,datinglabel)
显示如下:
3、归一化数据:我们很容易发现,在用欧拉公式(如下)计算两个样本数据距离时,方程中数字差值最大的属性对计算结果的影响最大,也就是说,每年获取的飞行常客里程数对于计算结果的影响将远远大于其他两个特征— 玩视频游戏的和每周消费冰洪淋公升数— 的影响。
而产生这种现象的唯一原因,仅仅是因为飞行常客里程数远大于其他特征值。但海伦认为这三种特征是同等重要的,因此作为三个等权重的特征之一,飞行常客里程数并不应该如此严重地影响到计算结果。:在处理这种不同取值范围的特征值时,我们通常采用的