目录
2. 使用 scikit - learn 实现 KNN 算法的步骤及详细解释
1. K 最近邻算法概述
K 最近邻(K-Nearest Neighbors,KNN)算法是一种基本的监督学习算法,既可以用于分类问题,也可以用于回归问题。其核心思想是:给定一个训练数据集,对于新的输入实例,在训练数据集中找到与该实例最邻近的 K 个实例,然后根据这 K 个实例的类别(分类问题)或数值(回归问题)来决定新实例的类别或值。
2. 使用 scikit - learn 实现 KNN 算法的步骤及详细解释
2.1 导入必要的库
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
numpy
:用于数值计算和数组操作。load_iris
:从sklearn.datasets
中导入,用于加载鸢尾花数据集,这是一个常用的分类数据集。train_test_split
:用于将数据集划分为训练集和测试集。KNeighborsClassifier
:scikit - learn
中实现 KNN 分类算法的类。accuracy_score
:用于评估分类模型的准确率。
2.2 加载和准备数据集
# 加载鸢尾花数据集
iris = load_iris()
X = iris.data # 特征数据,包含花的各种测量值
y = iris.target # 标签数据,代表花的类别
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
load_iris()
:加载鸢尾花数据集,返回一个包含数据和标签的对象。X
:存储特征数据,每一行代表一个样本,每一列代表一个特征。y
:存储标签数据,每个元素对应一个样本的类别。train_test_split
:将数据集按test_size = 0.3
的比例划分为训练集(70%)和测试集(30%),random_state = 42
保证每次划分的结果一致,方便结果复现。
2.3 创建 KNN 分类器并训练模型
# 创建 KNN 分类器,设置 K 值为 3
knn = KNeighborsClassifier(n_neighbors=3)
# 训练模型
knn.fit(X_train, y_train)
KNeighborsClassifier(n_neighbors = 3)
:创建一个 KNN 分类器对象,n_neighbors
参数指定 K 值,这里设置为 3,表示在进行预测时考虑最近的 3 个邻居。fit(X_train, y_train)
:使用训练集数据X_train
和对应的标签y_train
对 KNN 模型进行训练。在 KNN 算法中,“训练” 过程实际上只是存储训练数据,因为预测时才会进行距离计算和邻居查找。
2.4 进行预测并评估模型
# 对测试集进行预测
y_pred = knn.predict(X_test)
# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"KNN 模型的准确率: {accuracy}")
predict(X_test)
:使用训练好的 KNN 模型对测试集数据X_test
进行预测,返回预测的标签y_pred
。accuracy_score(y_test, y_pred)
:计算预测结果y_pred
与真实标签y_test
之间的准确率,即预测正确的样本数占总样本数的比例。
3. 场景示例
3.1 鸢尾花分类
上述代码示例就是一个典型的鸢尾花分类场景。鸢尾花数据集包含了鸢尾花的四个特征(花萼长度、花萼宽度、花瓣长度、花瓣宽度)和三个类别(Setosa、Versicolour、Virginica)。通过 KNN 算法,我们可以根据花的特征来预测它所属的类别。
3.2 手写数字识别
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
# 加载手写数字数据集
digits = load_digits()
X = digits.data # 特征数据,每个样本是一个 8x8 图像的像素值展开
y = digits.target # 标签数据,代表数字 0 - 9
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 创建 KNN 分类器,设置 K 值为 5
knn = KNeighborsClassifier(n_neighbors=5)
# 训练模型
knn.fit(X_train, y_train)
# 对测试集进行预测
y_pred = knn.predict(X_test)
# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"手写数字识别 KNN 模型的准确率: {accuracy}")
# 可视化部分预测结果
fig, axes = plt.subplots(2, 5, figsize=(10, 4))
axes = axes.flatten()
for i in range(10):
idx = np.random.randint(0, len(X_test))
image = X_test[idx].reshape(8, 8)
true_label = y_test[idx]
pred_label = y_pred[idx]
axes[i].imshow(image, cmap=plt.cm.gray_r)
axes[i].set_title(f"True: {true_label}, Pred: {pred_label}")
axes[i].axis('off')
plt.tight_layout()
plt.show()
在这个手写数字识别场景中,我们使用 load_digits
加载手写数字数据集,每个样本是一个 8x8 的图像,将其像素值展开作为特征。通过 KNN 算法进行训练和预测,并可视化部分预测结果,直观地展示模型的分类效果。