机器学习 | 基于Scikit-learn中手写数字集的交叉验证

本文详细介绍了交叉验证的概念,特别是在手写数字集上如何选择最佳参数以避免过拟合。通过Scikit-Learn示例展示了如何使用GridSearchCV对SVM进行参数搜索,以及交叉验证的优势和不足。
摘要由CSDN通过智能技术生成

在本文中,我们将讨论交叉验证及其在手写数字集上的使用。此外,我们将看到使用手写数字集的代码实现。

什么是交叉验证?

手写数字集的交叉验证将允许我们选择最佳参数,避免过度拟合训练数据集。它是一个试验的尝试程序,检查的交叉验证得分的每个参数,然后经过评估,选择最佳的程序。它也适用于商业工作流。

Scikit Learn中的Digits Dataset包含UCI ML手写数字数据集的副本。它是一个非常适合初学者的分类数据集,也是学习包括CNN在内的各种机器学习算法的良好数据集。

交叉验证是一种技术,我们使用数据集的子集训练模型,然后使用互补子集进行评估。交叉验证涉及的三个步骤如下:

  • 保留部分样本数据集。
  • 使用剩余的数据集训练模型。
  • 使用数据集的保留部分测试模型。

K折交叉验证:在这种方法中,我们将数据集分成k个子集(称为折叠),然后对所有子集进行训练,但留下一个(k-1)子集来评估训练模型。在这种方法中,我们使用每次为测试目的保留的不同子集进行重复测试。

要执行K折交叉验证,我们可以使用cross_val_score方法来执行验证。下面是语法:

cross_val_score(model, X, y, cv=5)

model:它是我们想要拟合数据的估计器。
X:是训练数据。
y:是标签的数量。
cv:表示(分层)K折叠中的折叠数。

我们可以使用GridSearchCV,它在我们将要执行的参数网格上执行穷举搜索。它接受以下参数:

GridSearchCV(model, param_grid, cv=kf, scoring=‘accuracy’)

model:它是我们想要拟合数据的估计器。
param_grid:它将在提供的所有参数值组合上运行
cv:交叉验证策略。
scoring:它定义了评估交叉验证模型在测试集上的性能的策略。

对数据集执行K折交叉验证

步骤1:导入库

导入进一步步骤所需的所有必要库。这段python代码演示了如何使用scikit-learn库执行网格搜索来调整支持向量机(SVM)分类器的超参数。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import KFold
from sklearn.datasets import load_digits
from sklearn.svm import SVC

步骤2:加载数据集

手写数字集通过load_digits(return_X_y=True)函数加载到该行中,该函数还将特征矩阵分配给X,并将关联的标签分配给y。

X, y = load_digits(return_X_y=True)

步骤3:使用numpy logspace定义网格参数

param_grid = {'C':  np.logspace(-5, 5, 10)}

步骤4:在Sklearn中定义Kfold对象,并创建SVM分类器

此代码片段创建了一个KFold交叉验证对象,并使用sigmoid内核实例化了一个支持向量机(SVM)分类器。

svm = SVC(kernel="sigmoid")
kf = KFold(n_splits=5, shuffle=True, random_state=42)

步骤5:现在我们需要使用交叉验证和SVM执行GridSearchCV

# performing exhaustive search
grid_search = GridSearchCV(svm, param_grid, cv=kf, scoring='accuracy', return_train_score=True, verbose=3, n_jobs=-1)

grid_search.fit(X,y)

步骤6:绘制并打印结果

通过网格搜索找到的平均交叉验证分数、标准差和最佳超参数值绘制在此代码片段中。

scores_avg = grid_search.cv_results_['mean_test_score']
scores_std = grid_search.cv_results_['std_test_score']
param_values = grid_search.cv_results_['param_C']
# Do the plotting
plt.figure()
plt.semilogx(param_values, scores_avg)
plt.semilogx(param_values, np.array(scores_avg) + np.array(scores_std), "r--")
plt.semilogx(param_values, np.array(scores_avg) - np.array(scores_std), "g--")
locs, labels = plt.yticks()
plt.yticks(locs, list(map(lambda x: "%g" % x, locs)))
plt.ylabel("CV score")
plt.xlabel("Parameter C")
plt.ylim(0, 1.1)
plt.show()


# Print the best score and parameters
print('Best score:', grid_search.best_score_)
print('Best C:', grid_search.best_params_['C'])

输出

Best score: 0.9115242958836273
Best C: 0.2782559402207126

在这里插入图片描述
解释:我们看到分数在10-2之后增加,为我们的支持向量机分类器提供了更好的参数C。最终在100之后,我们再次看到一个下降,然后几乎是一个恒定的值。这里的CV分数越高,情况越好。

交叉验证的优点和缺点

优点

  • 它提供了模型如何在未知数据上泛化的想法。
  • 它有助于估计模型预测的准确估计。
  • 交叉验证通过对未知数据提供更可靠的模型性能估计,有助于防止过拟合。
  • 它可以用来优化模型的超参数

缺点

  • 交叉验证需要更长的训练时间,因为我们多次拆分数据。例如,如果有5个折叠,并且我们的参数组合等于10,那么总共将有50次分裂和训练。当增加另一个参数时,它呈指数增长。
  • 它需要巨大的处理能力。
  • 交叉验证中折叠次数的选择会影响偏差-方差权衡,即,太少的折叠可能导致高方差,而太多的折叠可能导致高偏差
  • 12
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值