1 实验名称
使用Python实现KNN
2 实验目的
1.掌握KNN的算法原理
2.掌握math、collections、numpy等第三方包的基本使用
3 实验背景
KNN(K-NearestNeighbor,译K最近邻)分类算法是数据挖掘分类技术中最简单的方法之一,也是最常用的分类算法之一。
所谓K最近邻,就是K个最近的邻居的意思,说的是每个样本都可以用它最接近的K个邻近值来代表。近邻算法就是将数据集合中每一个记录进行分类的方法 。
本实验将演示使用Python的第三方库Numpy实现KNN算法,对训练样本进行建模,并对待测样本进行类别划分。
4 实验原理
K近邻(K Nearest Neighbors,KNN)算法,又称为KNN算法,是一种非常直观并且容易理解和实现的有监督分类算法。该算法的基本思想是寻找与待分类的样本在特征空间中距离最近的K个已标记样本(即K个近邻),以这些样本的标记为参考,通过投票等方式,将占比例最高的类别标记赋给待标记样本。该方法被形象地描述为“近朱者赤,近墨者黑”。
由算法的基本思想可知,KNN分类决策需要待标记样本与所有训练样本做比较,不具有显式的参数学习过程,在训练阶段仅仅是将样本保存起来,训练时间为零,可以看作直接预测。
KNN算法需要确定K值、距离度量和分类决策规则。
5 实验步骤(简单测试)
5.1 导入所需要的包
#导入所需要的包
import numpy as np
from math import sqrt
from collections import Counter
5.2 定义KNN函数
# KNN算法规则:
# 1、确定K值(一般取较小的奇数)
# 2、距离度量(欧氏距离)
# 3、分类决策(多数表决)
def distance(k, X_train, Y_train, x):
#shape三种使用方法:shape[0],shape[1],shape
# 1、shape[0] :读取行数
# 2、shape[1]:读取列数
# 3、shape:行列数组成元组直接输出
#保证K有效
assert 1 <= k <= X_train.shape[0], "K must be valid"
#X_train的值必须等于y_train的值
assert X_train.shape[0] == Y_train.shape[0], "the size of X_train must equal to the size of y_train"
#x的特征号必须等于X_train X_train的列 2个特征 x 2个元素
assert X_train.shape[1] == x.shape[0], "the feature number of x must be equal to X_train"
#迅速计算距离 列表生成器 返回列表
distance = [sqrt(np.sum((x_train - x)**2)) for x_train in X_train]
#返回距离值从小到大排序后的索引值的数组
nearest = np.argsort(distance)
#获取距离最小的前k个样本的标签
topk_y = [Y_train[i] for i in nearest[:k]]
#统计前k个样本的标签类别以及对应的频数
votes = Counter(topk_y)
#返回频数最多的类别
return votes.most_common(1)[0][0]
5.3 显示待测样本的标签完成分类
if __name__ == "__main__":
#使用numpy生成8个点 已量化训练样本
X_train = np.array([[1.0, 3.5],
[2.0, 7],
[3.0, 10.5],
[4.0, 14],
[5, 25],
[6, 30],
[7, 35],
[8, 40]])
#使用numpy生成8个点对应的类别 训练样本标签
Y_train = np.array([0, 0, 0, 0, 1, 1, 1, 1])
#使用numpy生成待分类样本点 待测样本
x = np.array([8, 21])
#调用distance函数并传入参数 k=3
label = distance(3, X_train, Y_train, x)
#显示待测样本点的分类结果
print(label)
提示:shape用法
print("X_train.shape[0]:",X_train.shape[0])
print("Y_train.shape[0]:",Y_train.shape[0])
print("X_train.shape[1]:",X_train.shape[1])
# 一维数组shape[0]返回元素个数,二维数组shape[0]返回行,shape[1]返回列
print("x.shape[0]:",x.shape[0])