k-近邻算法介绍
首先说明一下什么叫做k-近邻算法,k-近邻算法简称knn算法,其主要思想是当输入一个点(坐标)时,算法将找到离它最近的k个点,也就是这个点最近的k个邻居,对这k个点的标签进行分类,将这个点归类到k个邻居所归属的数目最多的分类中。
举个例子:
大家都知道“物以类聚,人以群分”这个成语吧,我们以此为一个场景,在钟吾国的某个城镇中,有陈、杨、李、周四个家族,总体上这个四个家族的成员是抱团定居的(也就是相同种姓的人更倾向于住在一起),现在我想知道某个屋子的主人是哪个家族的成员,但是门卫戒备森严,我们不能靠近,该怎么办呢?
首先我们,我们有一种思路,就是去调查他周围几个邻居,比如调查离宅子最近的五户人家(这样k=5了)是哪几个家族的人,假如五户人家有三户是陈家的,一户是杨家的,一户是李家的,那么这个大宅子是陈家的概率为3/5,杨家和李家的概率都是1/5,那么我们就认为这家是陈家的宅院。
这就是knn算法的大体思路啦。
代码详解
1.模块导入
from numpy import *
import operator
import matplotlib
import matplotlib.pyplot as plt
我们一共导入了三个模块:
- numpy模块:这是一个科学计算包,它拥有丰富的数学函数,强大的多维数组以及优异的运算性能,主要用来对矩阵进行运算。
- operato模块:这是一个运算符模块,它为我们提供了大量可以用来替代代数操作符的函数,在之后我们会对使用到的部分函数进行介绍。
- matplotlib模块:这是一个用来完成平面画图的非常强大的模块。
2.构建分类器
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
diffMat = tile(inX, (dataSetSize, 1)) - dataSet
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances**0.5
sortedDistIndicies = distances.argsort()
classCount = {}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
sortedClassCount = sorted(
classCount.items(),
key=operator.itemgetter(1),
reverse=True)
return sortedClassCount[0][0]
我们为分类器定义了classify0函数,这个函数输入的参数有4个,第一个是inX,代表一个点的坐标(想知道姓氏的住宅的位置),第二个是dataSet:数据集(地图上所有住宅的位置),第三个是labels:标签集(地图上所有住宅所属的家族),最后一个是k,这是一个常量(选取的最近邻居的数目)。
我们首先介绍几个函数:
- tile()函数:这是由numpy提供的一个函数,他的功能是重复某个数组。
from numpy import *
a = [0, 1, 2]
b = tile(a, (1, 2))
c = tile(a, (2, 1))
输出b和c的结果是:[[0, 1, 2, 0, 1, 2]]和[[0, 1, 2], [0, 1, 2]],注意,c应该是上下排列的,我这里图省事了^ _ ^。
- sum(),顾名思义这是一个求和函数,在程序中有个(axis=0),这代表按行相加(第一列所有行相加、第二行所有列相加……),如果axis=1,则为按列相加(第一行所有列相加、第二行所有列相加……)。
import numpy as np
a = np.array([[1, 2], [3, 4]])
print(np.sum(a, axis=0))
print(np.sum(a, axis=1))
我们输出后得到:b[4, 6], c[3, 7]
- argsort()函数:这个函数返回的是将数组从小到大排列后的索引
import numpy as np
a = np.array([1, 3, 2, 4])
print(np.argsort(a))
输出结果为[0, 2, 1, 3],需要注意的是,索引值是从0开始排序的!
- get()函数,这是在字典中的应用,返回的是指定键的值。
- items()函数,这个函数以列表返回可遍历的(键, 值) 元组数组。
dict = {'Google': 'www.google.com', 'baidu': 'www.baidu.com', 'CSDN': 'www.csdn.net'}
print("字典值 : %s" % dict.items())
# 遍历字典列表
for key, values in dict.items():
print(key, values)
结果为:
- itemgetter()函数:这个函数就是由之前提到的operator模块提供的啦!它可以获取对象某些维度的数据
欧克,分类器中需要用到的函数我们已经介绍完了,接下来我们讲一下这个分类器。这个分类器主要用的是欧式距离公式,不清楚欧式距离公式看这里👉欧几里得度量,这也是knn的核心公式啦。
下面我们来看具体操作,我们首先把目标点inX的坐标拓展成和数据集同样的形式(shape[0]行1列),然后和数据集相减,这样我们得到的就是数据集中每个点和inX的(x1-x2,y1-y2)组成的数组,我们再平方和就可以得到数据集中每个点和点inX的距离,我们将距离从小到大排列,取前k个,看哪个类别最多返回出来就行啦。(在这里的处理方法是将k个最近邻居的标签插入到classCount中,再对所有classCount中的标签分类计数,将数目最多的标签返回出来。)
我们已经知道了knn算法的工作原理,可是如何将这个算法应用到实际中是一个难题,那么我们看看如何将knn应用到约会网站的匹配中吧!
3. 读取并处理文件
我们都知道机器学习的使用第一步就是从数据集中获取数据,没有数据,再好的算法都不可能有用武之地。读取数据是一个很简单的操作,大部分算法对数据的读取都是大同小异的。
在这里呢,我们的数据集拥有三个特征,分别是:每年获得的飞行常客里程数、玩游戏视频所耗百分比、每周消费的冰淇淋公升数,在应用机器学习之前对数据集进行一定的了解是十分必要的。接下来让我们来看代码是如何将文件的数据读成所需要的矩阵的~
def file2matrix(filename):
fr = open(filename)
arrayOLines = fr.readlines()
numberOfLine = len(arrayOLines)
returnMat = zeros((numberOfLine, 3))
classLabelVector = []
index = 0
for line in arrayOLines:
line = line.strip()
listFromLine = line.split('\t')
returnMat[index, :] = listFromLine[0:3]
classLabelVector.append(int(listFromLine[-1]))
index += 1
return returnMat, classLabelVector
首先,我们利用open()函数读入文件,按行为单位返回字符串,然后获得字符串的长度len,这就获得了数据的行数,那么我们就可以新建一个数据矩阵(数据集)啦,它拥有len行,3列,我们初始化每个位置都为0。然后我们对读取的字符串的每一行进行如下操作:
1. 去掉头尾的空格
2. 以转义字符\t为根据进行分割,并将前三列并放入数据矩阵的对应位置。
3. 将每行的最后一个数据,也就是标签放入标签集中
然后我们将数据集和标签集返回出来,就可以得到数据啦。
4.对数据进行可视化
我们常常很难直接从数据集中直接读出信息,如果我们使用某种方法将数据进行可视化,那么结果就可以变得明了的多,具体如下:
data, label = file2matrix('datingTestSet2.txt')
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(data[:, 0], data[:, 1], 15.0*array(label), 15.0*array(label))
plt.savefig('散点图.png')
plt.show()
制作出的散点图长这个样子👇!
其中横坐标是飞行里数,纵坐标每周游戏时间比,从图中我们可以发现这位小姐喜欢每年飞适当一段距离,每周也能陪她打一会游戏的,不喜欢那种需要飞较长时间的,也就是经常出差的!(无奈╮(╯▽╰)╭)
5.对数据进行归一化
之前的处理看似非常合理,但其实我们忽略了一个问题:
你们看,如果我需要计算两个点的距离,他们的坐标分别为(0, 20000, 1.1)(67, 32000, 0.1),那么两点距离为:
没错吧,但是我们发现,相对于飞行里程,冰淇淋公升数和游戏时间实在太微不足道了,但我们至少希望他们能够拥有相同的重要性,那么我们就需要对数据进行归一化,归一化的思想是不再原来的数值来表示数据,而是将数据缩小在(0,1)这个范围内,处理方法如下👇
这样,每个数值都控制在0到1中间了,具体代码如下👇
def autoNorm(dataSet):
minVals = dataSet.min(0) # 0使得函数从列表中选取最小值
maxVals = dataSet.max(0)
ranges = maxVals - minVals
normDataSet = zeros(shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet - tile(minVals, (m, 1))
normDataSet = normDataSet / tile(ranges, (m, 1))
return normDataSet, ranges, minVals
这个函数十分简单,就是先获得每一列的最大值与最小值。然后进行对应的计算即可。最后将归一化的矩阵返回。
6. 测试
在前面我们已经完成了分类器的构建、对数据的读取与处理,但是我们并不知道我们所做的分类器的实际分类效果如何,这就需要我们自己对分类器进行测试。通常情况下,我们将数据的90%作为数据集,剩下的10%作为测试集,测试代码如下👇
def datingClassTest():
hoRatio = 0.1
datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
# normMat:(oldnum-min)/(max-min)即归一化后的数据; ranges:max-min
normMat, ranges, minVals = autoNorm(datingDataMat)
m = normMat.shape[0]
numTestVecs = int(m * hoRatio)
errorCount = 0.0
for i in range(numTestVecs):
classifierResult = classify0(
normMat[i, :], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
print(
"the classifer came back with: %d ,the real answer is: %d " %
(classifierResult, datingLabels[i]))
if(classifierResult != datingLabels[i]):
errorCount += 1.0
print("the total error rate is %f" % (errorCount / float(numTestVecs)))
我们读入数据后,提前设置好测试集所占的比例0.1,利用之前的函数将数据归一化,由于这个数据并没有排序,所以我们可以随机选择10%的数据放入测试集,这里我们选择的是前10%的数据,首先对测试集中每一个数据和数据集中的数据利用分类器进行分类,k取3,然后输出分类结果和正确值,如果不正确,错误数+1,最后用总错误数 / 测试集总数 获得错误率并输出出来。
在我的电脑上错误率为0.05%,还算差强人意吧。
7.构建完整的可用系统
经过上面的测试,我们可以知道分类器的分类结果是不错的,那么我们就可以应用起来,也就是对之前的所有函数进行整合与封装。代码如下👇
def classifyPerson():
resultList = ['not at all', 'in small does', 'in large does']
percentTatts = float(input("percentage of time spent playing video games"))
ffMiles = float(input("frequent flier miles earned per year?"))
iceCream = float(input("liters of icecream consumed per year?"))
datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
normMat, ranges, minVals = autoNorm(datingDataMat)
inArr = np.array([ffMiles, percentTatts, iceCream])
classifierResult = classify0(
(inArr - minVals) / ranges,
normMat,
datingLabels,
3)
print("you will like this people ", resultList[classifierResult - 1])
我们首先为标签1、2、3分别赋值为not at all’, ‘in small does’, ‘in large does’,然后我们要求用户输入对象的三标,然后将预测结果输出出来就行啦。
如果大家觉得对自己有所帮助的话,就点个赞吧,比心心😘