KNN(K-Nearest Neighbors)
算法思想
基本思想:
基于给定的一个训练样本集合D和k值,现有待预测样本x(无标签);在D中找到与x距离最近的k个样本,①若是分类问题,则通过投票法选择这k个样本中出现次数最多的类别作为x的预测标签;②若是回归问题,对这k个样本的标签求平均值,得到x的预测结果。
算法三要素:k值,距离,k个近邻的快速检索方法
①k值的选取:当k值很小时,得到的结果可能偏差大;当k值很大时,可能会将大量其它类别的样本包含进来从而导致预测错误;经验选择:可通过交叉验证、在验证集上多次尝试不同的k值来挑选最佳k值。
②距离的度量:一般而言,任意两个样本A和B之间的距离为闵氏距离,即
,其中n为样本的特征数量,通常p取2(此时为欧式距离)。上述距离一般针对连续变量,对于离散变量,可先将离散变量连续化在应用距离计算。
③k个近邻的样本的快速检索方法:
常用的一个可取方法是为训练样本事先建立索引以减少计算的规模,如通过kd树(k-dimensional tree,这里的k不同于KNN中的k值,而是对应于样本的特征数量,即②中的n)。
基本定义:KD树是一种对k维空间数据进行存储以便对其进行快速检索的树形数据结构。KD树的每个节点都是一个k维点,并且所有非叶子节点可以视为使用一个超平面将空间分割成两个半空间,节点左边、右边的子树分别代表在超平面左边、右边的点。
构建过程
·选择分割维度:在构建KD树时,需要选择用于分割的维度。通常,这一选择基于数据在各个维度上的方差,选择方差最大的维度作为分割维度,因为这样可以获得最好的分辨率。
·确定分割点:对于选定的分割维度,选择该维度上数据的中位数作为分割点,这样可以确保分割尽可能平衡。
·递归构建:对于分割后的左右两部分数据,递归地应用上述过程,直到所有数据点都被分配到叶节点。
搜索过程
·最近邻搜索:如应用在KNN中的最近邻搜索。搜索过程从根节点开始,根据目标点与分割超平面的相对位置递归地进入左子树或右子树,直到达到叶节点。然后,通过回溯过程,检查是否有比当前最近点更近的点。。
优点与局限
KD树适合维数较低、训练集合规模大的样本数据上,此时能够高效地处理多维数据的搜索问题,尽可能地实现快速检索。但随着样本维度的增加,KD树的性能会急剧下降。
代码应用
在scikit-learn库中,KNeighborsClassifier类提供了KNN算法的实现。
从sklearn.neighbors导入KneighborsClassifier类并实例化,即创建一个KNN分类器的对象,即可根据这个对象进行KNN算法的执行。注意另外还有一个KneighborsRegressor类用于使用KNN算法进行回归问题的求解。
这个对象有多个参数(包含距离度量的指定和k个近邻的快速检索方法的指定等),核心参数是n_neighbors,它指定了用于预测的邻居数量k。其它参数参照官网文档。通过help()函数传入KneighborsClassifier即可查看该类的文档信息。
KneighborsClassifier对象的常用方法
fit(X, y)
功能:训练KNN分类器。这个方法计算所有训练样本之间的距离,并将这些信息存储在模型中,以便后续的分类任务。值得注意的是,实际上该方法并不会直接存储所有训练样本之间的距离,而是通过构建有效的数据结构(如kd树)和存储必要的参数来支持后续的快速距离计算和邻居查找。
·参数:(N为当前输入样本的总数量,num_features为每个样本的特征数量;下同)
X:训练数据的特征矩阵,形状为[N, num_features]。
y:训练数据的标签数组,形状为[N]。
·返回值:无返回值,但训练好的模型存储在knn对象中。
predict(X)
·功能:使用训练好的KNN分类器对新的数据点进行预测。
·参数:
X:待预测数据点的特征矩阵,形状为[N, num_features]。
·返回值:一个形状为[N]数组,包含了每个输入样本的预测类别。
score(X, y)
·功能:评估模型在给定测试集上的性能。默认情况下,它返回准确率(accuracy),即正确预测的样本数与总样本数的比例。
·参数:与fit()方法参数类似
·返回值:一个浮点数,表示拟合好模型的在当前样本集(X,y)上的预测准确率。
实验-鸢尾花分类
数据集
使用sklearn.datasets,模块中的load_iris()方法可获取鸢尾花数据,该方法返回的对象(这里将该对象命名为iris)包含该数据集数据及其相关信息。iris对象的data属性和target属性分别获取样本特征数据和样本标签信息,样本总数为150,样本特征数量、标签数量(类别数量)分别为4、3。iris的feature_names和target_names属性分别获取数据集的特征名称和标签名称;样本特征分别为sepal length 花萼长度、sepal width 花萼宽度、petal length 花瓣长度、petal width 花瓣宽度。
获取数据并拆分为训练集、测试集:
可视化数据集的数据点举例:
根据数组X(其形状为(N, 2),表示每个样本有两个特征)和数组Y(其形状为(N,),表示每个样本的类别标签)来绘制散点图,其中不同类别的样本用不同的颜色表示。
运行结果
模型拟合、评分、预测:
根据指定的k值创建KNN对象,然后通过fit()方法拟合训练集数据,score()方法获取本次KNN算法应用在训练集、测试集上的分类准确率,predict()方法获取全部样本的预测标签,然后打印基本的分类结果信息;通过上述流程对比不同k值对本次KNN实验结果的影响。注意,因为鸢尾花数据集的样本数量较少,k值的影响可能不太能体现,且实验中的k值要 小于等于 用于拟合的训练集的样本数量大小。
k=3时的测试样本的预测结果
注:关注微信公众号——分享之心,后台回复“机器学习基础实验”获取完整代码和相关文档资料的地址(不断更新)。
上一篇: 逻辑回归实现乳腺癌预测-numpy实现与MindSpore实现