KNN算法
KNN(K-nearst neighbors)
1,理论表述:
KNN算法的全称为K近邻算法,结合英文不难理解,就是每个样本都可以用和他最邻近的k个数据所代表的类别来刻画。
这个算法不仅能应用到分类算法中,也能应用到回归算法中。那他究竟能干什么事情呢?下面看看他的实际应用吧(分类):
上面我们要对唐人街探案电影进行分类。
如何分类呢?我们不妨统计特定的镜头来当做参数,
然后根据特定的距离公式算法,找出来了六个与唐人街探案最近的“邻居”,其中五个为喜剧片,一个为爱情片,所以我们把唐探归到喜剧片中。再举一个例子,假如你想知道一个陌生人是一个怎么样的人,怎么办呢?你可以通过与他关系好的几个人中的共性大致的推断出他的为人处世。
2,直观表示:
KNN实现
1,距离的计算
从上面不难看出,要使用KNN算法,不可避免要涉及到距离的计算问题。下面介绍一些常用的距离算法:
(1)欧氏距离:最常见的两点之间或多点之间的距离表示法,又称之为欧几里得度量,它定义于欧几里得空间中。n维空间中两个点P(x1,x2,…,xn)与 Q(y1,y2,y3…,yn)间的距离:
d
=
∑
i
=
1
n
(
x
i
−
y
i
)
2
d=\sqrt {\displaystyle\sum_{i=1}^{n} (x_i-y_i)^2}
d=i=1∑n(xi−yi)2
(2) 曼哈顿距离:曼哈顿距离对应L1-范数,也就是在欧几里得空间的固定直角坐标系上两点所形成的线段对轴产生的投影的距离总和。例如在平面上,坐标(x1, y1)的点P1与坐标(x2, y2)的点P2的
∣
x
1
−
x
2
∣
+
∣
y
1
−
y
2
∣
|x_1-x_2|+|y_1-y_2|
∣x1−x2∣+∣y1−y2∣要注意的是,曼哈顿距离依赖座标系统的转度,而非系统在坐标轴上的平移或映射。
L
1
=
∑
i
=
1
n
∣
x
1
i
−
x
2
i
∣
L_1=\sum_{i=1}^{n}|x_{1i}-x{2i}|
L1=i=1∑n∣x1i−x2i∣
(3)切比雪夫距离:n维空间点a(x11,x12,…,x1n)与b(x21,x22,…,x2n)的切比雪夫距离:
d
=
m
a
x
(
x
1
i
−
x
2
i
)
d=max(x_{1i}-x_{2i})
d=max(x1i−x2i)
(4)余弦距离
(5)闵可夫斯基距离
。。。。距离的计算有好多种,在这里不一一罗列了,有兴趣的朋友可以多了解一下。
2,k值的确定
为什么会牵扯到k的取值问题呢?因为你会发现,赋予k一个怎样的值,直接影响着分类的结果。
k的取值无外乎两种,过大或过小。
(1)k值过小
k的取值过小,取到的已分类样品较少,无法对需分类物品进行分类。那就回导致过拟合,也就是模型过于拟合训练集数据,只有训练集极近的样品才能被分类。基本没有实际价值。
(2)k值过大
值过于大,与小相比,那就是欠拟合,那什么是欠拟合呢?
欠拟合可以理解为与样品点距离较远的点也会对分类结果造成不必要的影响,使得分类的结果偏差较大。
(3)取值红灯
1,k的取值不能是类别数的倍数。就比如我需要把一个水果分到已知的苹果和香蕉两类中,那么k的取值就不能是2,4,6,8…2n因为会出现apple(n)=banana(n)的情况,那么就会出现分类错误。
2,当我们取得值与多个计算的距离值相等时,我们就要考虑k的重新取值了,当然也可以换一种距离的计算方法。
kNN应用(鸢尾花数据集)
1,前提数据准备
Iris 鸢尾花数据集内包含 3 类,分别为山鸢尾(Iris-setosa)、变色鸢尾(Iris-versicolor)和维吉尼亚鸢尾(Iris-virginica),共 150 条记录,每类各 50 个数据,每条记录都有 4 项特征:花萼长度(sl)、花萼宽度(sw)、花瓣长度(pl)、花瓣宽度(pw),可以通过这 4 个特征预测鸢尾花卉属于哪一品种。
数据获取可参照本人另一篇博客的最后一个实例
本实例用到的数据如下图所示:
共150行,五列。
2,代码实现
由于为了方便应用欧式距离计算,这里只取前两个自变量,sl和sw。
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
from math import sqrt
##### 数据准备 ####
iris = pd.read_csv('E:\Python\Iris.csv')
num_iris = len(iris)
#将 3 种类型分别映射为 0,1,2
iris["type"] = iris["type"].map({"Iris-setosa":0,"Iris-versicolor":1,"Iris-virginica":2})
#定义一个测试集
test_data = [[5.5,5.2,7.0,5.6],[3.6,2.3,2.9,3.2]]#[[花萼长度],[花萼宽度]],一共四个测试数据
print("测试数据:")
print("花萼长度:",end='')
print(test_data[0])
print("花萼宽度:",end='')
print(test_data[1])
##### 数据整理 #####
iris_0 = [[],[]]
iris_1 = [[],[]]
iris_2 = [[],[]]#分别对应三种鸢尾花,两个子列表分别存储前两列的数据
iris_type = iris.type
for i in range(num_iris):
if iris_type[i] == 0:
iris_0[0].append(iris.sl[i])
iris_0[1].append(iris.sw[i])
elif iris_type[i] == 1:
iris_1[0].append(iris.sl[i])
iris_1[1].append(iris.sw[i])
else:
iris_2[0].append(iris.sl[i])
iris_2[1].append(iris.sw[i])
#####KNN 对测试集进行分类 ######
#定义欧氏距离
def Euclid(x1,y1,x2,y2):
d = sqrt((x1-x2)**2+(y1-y2)**2)
return d
def _KNN_(x,y):
K = 5
#====== 计算距离 =======#
distance_0 = []
distance_1 = []
distance_2 = []
distances = []
#计算并记录距离
for i in range(50):
d = Euclid(x,y,iris_0[0][i],iris_0[1][i])
distance_0.append(d)
for i in range(50):
d = Euclid(x,y,iris_1[0][i],iris_1[1][i])
distance_1.append(d)
for i in range(50):
d = Euclid(x,y,iris_2[0][i],iris_2[1][i])
distance_2.append(d)
#由小到大排序(此处使用冒泡排序)
distances = distance_0 + distance_1 + distance_2
for i in range(len(distances)-1):
for j in range(len(distances)-i-1):
if distances[j] > distances[j+1]:
distances[j],distances[j+1]=distances[j+1],distances[j]
#======== 决策划分 ========#
#定义删除函数,避免对同一个数据重复计算
def delete(a,b,ls):
for i in range(b):
if ls[i] == a:
ls.pop(i)
break
#找出与测试数据最接近的 K 个点
number_0 = number_1 = number_2 = 0
for i in range(K):
if distances[i] in distance_0:
number_0 += 1
delete(distances[i],len(distance_0),distance_0)
continue
if distances[i] in distance_1:
number_1 += 1
delete(distances[i],len(distance_1),distance_1)
continue
if distances[i] in distance_2:
number_2 += 1
delete(distances[i],len(distance_2),distance_2)
continue
max_number =max(number_0,number_1,number_2)
if max_number == number_0:
return 0
elif max_number == number_1:
return 1
else:
return 2
print("这四个测试数据分别属于:")
for i in range(len(test_data[0])):
m=_KNN_(test_data[0][i],test_data[1][i])
if m == 0:
print("Iris-setosa")
elif m == 1:
print("Iris-versicolor")
else:
print("Iris-virginica")
##### 画图 #####
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
#训练集
plt.scatter(iris_0[0],iris_0[1],marker='o',label='Iris-setosa',color='blue')
plt.scatter(iris_1[0],iris_1[1],marker='x',label='Iris-versicolor',color='black')
plt.scatter(iris_2[0],iris_2[1],marker='s',label='Iris-virginica',color='green')
#测试集
plt.scatter(test_data[0],test_data[1],marker='^',label='测试数据',color='red')
plt.legend()
plt.show()