鸢尾花数据集介绍
鸢尾花数据集包含了三个类别的鸢尾花样本:Setosa、Versicolor和Virginica。每个样本有四个特征:花萼长度、花萼宽度、花瓣长度和花瓣宽度。我们的目标是通过这些特征来预测鸢尾花的类别。
K近邻算法简介
K近邻算法是一种简单而有效的分类算法。它的基本思想是:对于一个未知样本,通过计算其与训练集中所有样本的距离,选取距离最近的k个样本,然后根据这k个样本的类别来预测未知样本的类别。
超参数搜索
在K近邻算法中,k值是一个重要的超参数。不同的k值可能会导致模型性能的显著变化。因此,我们需要通过超参数搜索来找到最优的k值。这里我们将使用网格搜索(Grid Search)来进行超参数搜索。
网格搜索
网格搜索是一种通过遍历给定超参数的所有可能组合来确定最优超参数的方法。为了进行网格搜索,我们首先要定义搜索的k值范围,然后在这个范围内尝试所有可能的k值,然后评估每个k值对应的模型性能,并选择性能最优的k值。
1. 网格搜索的原理
网格搜索的核心思想非常简单:穷举搜索给定的超参数组合,通过交叉验证来评估每个组合的性能,最终选择表现最好的一组参数。这个过程类似于在一个参数的二维网格上搜索最优点,因此得名“网格搜索”。
假设我们有两个超参数需要调优:参数A和参数B,它们的候选取值分别为A1、A2和A3,B1、B2和B3。网格搜索会依次尝试以下参数组合:
- (A1, B1), (A1, B2), (A1, B3)
- (A2, B1), (A2, B2), (A2, B3)
- (A3, B1), (A3, B2), (A3, B3)
对于每个参数组合,我们使用交叉验证来评估模型在训练集上的性能。最后,选择性能最佳的一组超参数作为我们模型的最终选择。
2. 网格搜索的优势
网格搜索作为调参方法有以下几个优势:
- 全面性:网格搜索尝试了所有可能的参数组合,确保我们不会错过最优解。
- 直观:网格搜索简单直观,易于理解和实现。
- 可复现性:给定相同的超参数范围和步长,网格搜索的结果是可复现的。
然而,网格搜索也有其缺点,主要体现在计算成本方面。随着超参数数量的增加,搜索空间会呈指数级增长,导致网格搜索的计算复杂度急剧上升。因此,对于高维参数空间和大型数据集,网格搜索可能不是最优的选择。
代码
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
# 加载数据集
iris = load_iris()
X, y = iris.data, iris.target
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 定义要搜索的k值范围
param_grid = {'n_neighbors': [1, 3, 5, 7, 9, 11, 13, 15]}
# 创建K近邻分类器
knn = KNeighborsClassifier()
# 初始化网格搜索对象
grid_search = GridSearchCV(knn, param_grid, cv=5)
# 在训练集上进行网格搜索
grid_search.fit(X_train, y_train)
# 方式1直接比较预测值和真实值
y_pred = grid_search.predict(X_test)
print(y_pred == y_test)
print("准确率:", sum(y_pred == y_test) / len(y_test))
# 方式2计算在测试集上的准确率
score = grid_search.score(X_test, y_test)
print("准确率:", score)
# 输出最优的k值和对应的准确率
print("最优的k值:", grid_search.best_params_['n_neighbors'])
print("最优的准确率:", grid_search.best_score_)
print("最优的模型:", grid_search.best_estimator_)
在这个示例代码中,我们使用了sklearn
库中的GridSearchCV
类进行网格搜索。我们指定了要搜索的k值范围为[1, 3, 5, 7, 9, 11, 13, 15],然后网格搜索会尝试这些k值,并返回最优的k值和对应的准确率。