1、KNN模型简介
KNN模型,是最简单的机器学习算法之一,其作用是以全部训练样本作为代表点,通过计算未知样本与所有训练样本的距离,并以最近邻者的类别作为决策未知样本类别的依据;即根据测试数据与k个已知点的最短距离来划分未知数据的类别。为更好理解其算法原理,我们对几组概念进行说明。
训练数据,即用于训练模型的数据。所谓训练数据,其本质就是将该数据作为模型运算的已知数据;换句话讲,以后模型的运算都是基于该训练数据。
训练数据,即需要判断类别的数据。
2、KNN算法可视化理解
为了更好理解其算法原理,我们以下列图片为例进行讲解。
如图,其中蓝色方块与红色三角形是已知数据(即训练数据),现要判断绿色圆形是属于,还是属于?
可以很容易知道,当k=3时,即根据周围最近的3个已知样本的距离进行判断,其中有2个而有1个,因此在k=3的情况下属于。同理,当k=6时, 只有2个而有4个,因此在k=6的情况下属于。
其数学表达式为:
我们已经知道KNN是以最近邻者的类别作为决策未知样本类别的依据,那么怎么计算距离呢?距离度量的方式有三种:欧式距离、曼哈顿距离、闵可夫斯基距离,通常用的是欧式距离(即两点间的直线距离)。
3、KNN的算法过程
KNN模型的三要素:k值的选取,距离度量方式和分类决策规则。
k值的选取和距离度量方式好理解。重点是分类决策规则,分类决策规则一般使用多数表决法(即少数服从多数)。通俗讲,就是一个未知样本的类别取决于它周围k个最近的样本中所占数量最多的类别。
算法过程:
- 输入训练集数据和标签,输入测试数据;
- 计算测试数据与各个训练数据之间的距离;
- 按照距离的递增关系进行排序,选取距离最小的K个点;
- 根据K个点中各类别所占的比例,返回比例最大的类别作为测试数据的类别。
4、代码实现:
#导包
import matplotlib.pyplot as plt
import numpy as np
plt.rcParams['font.sans-serif']=['SimHei'] #为了在绘图时能显示中文
plt.rcParams["axes.unicode_minus"] = False #为了在绘图时能显示负坐标轴的负号
#导入训练数据
train_data=np.loadtxt(r'C:\Users\trainData.txt')
#获取ndarray对象input_data所有维数的第3个元素,[:,3]中:表示遍历input_data所有维数,3表示每一维中的索引为2的元素
train_data_label=train_data[:,2]
#已知类别,储存类别对应的坐标
#where函数,用两组索引数组来表示值的位置,返回的第一个array表示行坐标,第二个array表示纵坐标
xy_1=train_data[np.where(train_data_label==1)]
xy_2=train_data[np.where(train_data_label==2)]
#scatter(),散点图绘画,s代表点的大小
#[:,0]中:表示遍历input_data所有维数,3表示每一维中的索引为2的元素
plt.scatter(xy_1[:,0],xy_1[:,1],color='r',s=2,label="训练数据 1")
plt.scatter(xy_2[:,0],xy_2[:,1],color='b',s=2,label="训练数据 2")
#导入测试数据,此处为生成10000个test数据
test_x=np.linspace(-6,5,100) #linspace()函数用于创建一个由等差数列构成的一维数组,用来均匀创建test数据的坐标
test_y=np.linspace(-5,5,100)
k=int(input('请输入奇数K:'))#确定K的取值,为方便
x1=[]#用于储存test数据中属于第一类数据的x坐标
y1=[]#用于储存test数据中属于第一类数据的y坐标
x2=[]
y2=[]
#enumerate函数用来遍历传入的对象test_y,并返回对应元素的索引,j为索引,y为y轴坐标
for j,y in enumerate(test_y):
for i,x in enumerate(test_x):
knn=np.empty([0,2],dtype=int)#用empty函数创建一维二列的数组,可以进行行向添加数组
for point in train_data:
dist=np.sqrt((x-point[0])**2+(y-point[1])**2)#计算测试点与训练点的距离
if dist>=0:
knn=np.append(knn,[[point[2],dist]],axis=0)#将训练点的类别与计算的距离平方进行绑定,并储存添加到数组knn中
maxdist=max(knn[:,1])#用max函数获取knn存储数据中距离的最大值
if knn.shape[0]>k:#保证knn存储的数据个数为k个
knn=np.delete(knn,np.where(knn[:,1]==maxdist),axis=0)#在每次循环中删除最大距离
#此处更新每次循环的最大值而不用sort函数排序,最大程度上减少了内存开销,提升运行速度
label_1=0#用于统计K个点中各类别的个数
label_2=0
for kn in knn:#遍历knn中的元素
if kn[0]==1:
label_1+=1
else:
label_2+=1
if label_1>label_2:
x1.append(x)
y1.append(y)
else:
x2.append(x)
y2.append(y)
#绘图
plt.scatter(x1,y1,color='r',marker='o',s=0.5,alpha=0.8,label="属于第一类")
plt.scatter(x2,y2,color='b',marker='*',s=0.5,alpha=0.8,label="属于第二类")
plt.legend()#显示图例
plt.show()
5、训练样本以及检测结果
训练数据散点图
检测结果
本文仅供学习交流