【scikit-learn】sklearn.model_selection.GridSearchCV 类:网格搜索交叉验证 超参数优化

sklearn.model_selection.GridSearchCV

GridSearchCVsklearn.model_selection 提供的 超参数优化工具,用于 遍历所有可能的超参数组合,通过 交叉验证 选择 最佳超参数


1. GridSearchCV 作用

  • 自动搜索最佳超参数组合,提高模型性能。
  • 使用交叉验证(默认 cv=5 评估不同超参数的效果。
  • 适用于分类和回归任务,支持 不同评分指标accuracyf1roc_aucr2 等)。

2. GridSearchCV 代码示例

(1) 超参数优化 SVM

from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
from sklearn.datasets import load_iris

# 加载数据
iris = load_iris()
X, y = iris.data, iris.target

# 设定超参数搜索空间
param_grid = {
    "C": [0.1, 1, 10],  # 正则化参数
    "kernel": ["linear", "rbf"]  # 选择不同核函数
}

# 初始化 SVM 并进行网格搜索
grid_search = GridSearchCV(SVC(), param_grid, cv=5)
grid_search.fit(X, y)

print("最佳参数:", grid_search.best_params_)
print("最佳得分:", grid_search.best_score_)

输出

最佳参数: {'C': 1, 'kernel': 'linear'}
最佳得分: 0.98

解释

  • GridSearchCV 遍历所有可能的 Ckernel 组合,选择最优超参数。

(2) 使用 StratifiedKFold 进行分层交叉验证

from sklearn.model_selection import StratifiedKFold

cv = StratifiedKFold(n_splits=5)
grid_search = GridSearchCV(SVC(), param_grid, cv=cv)
grid_search.fit(X, y)

print("最佳参数:", grid_search.best_params_)

解释

  • 适用于类别不均衡数据,保证交叉验证每折类别比例一致。

(3) 选择不同评分指标

grid_search = GridSearchCV(SVC(), param_grid, cv=5, scoring="accuracy")
grid_search.fit(X, y)
print("最佳参数:", grid_search.best_params_)

可选评分指标

任务评分指标 (scoring)说明
分类"accuracy"准确率
分类"f1"F1-score
分类"roc_auc"ROC AUC
回归"r2"R²(决定系数)
回归"neg_mean_absolute_error"负 MAE
回归"neg_mean_squared_error"负 MSE

示例:

grid_search = GridSearchCV(SVC(), param_grid, cv=5, scoring="f1_macro")
grid_search.fit(X, y)
print("最佳参数:", grid_search.best_params_)

(4) 训练多个模型

from sklearn.ensemble import RandomForestClassifier

param_grid = {
    "n_estimators": [10, 50, 100],
    "max_depth": [None, 10, 20]
}

grid_search = GridSearchCV(RandomForestClassifier(), param_grid, cv=5)
grid_search.fit(X, y)

print("最佳参数:", grid_search.best_params_)

解释

  • GridSearchCV 可用于多种模型,如 RandomForestClassifier

3. GridSearchCV 的参数

GridSearchCV(estimator, param_grid, scoring=None, cv=None, n_jobs=None, verbose=0)
参数说明
estimator评估器(模型),如 SVC()
param_grid需要搜索的超参数
scoring评分指标(如 "accuracy""f1""roc_auc""r2"
cv交叉验证策略(默认 5,可传 KFold()StratifiedKFold()
n_jobs并行计算(-1 表示使用所有 CPU 核心)
verbose是否打印搜索过程(0=不输出,1=简单输出,2=详细输出)

4. 适用场景

  • 超参数优化,提高模型性能
  • 分类/回归任务的模型调优
  • 结合 StratifiedKFold 处理类别不均衡数据

5. GridSearchCV vs. RandomizedSearchCV vs. train_test_split

方法适用情况作用
GridSearchCV参数范围较小,计算量可控遍历所有参数组合
RandomizedSearchCV参数范围较大随机选择部分参数搜索
train_test_split简单模型训练训练集/测试集划分

示例:

from sklearn.model_selection import RandomizedSearchCV
from scipy.stats import uniform

# 设定参数分布
param_dist = {'C': uniform(0.1, 10), 'kernel': ['linear', 'rbf']}

# 进行随机搜索
random_search = RandomizedSearchCV(SVC(), param_dist, n_iter=5, cv=5, random_state=42)
random_search.fit(X, y)

print("最佳参数:", random_search.best_params_)

解释

  • RandomizedSearchCV 随机选取参数组合,适用于 大参数空间

6. 结论

  • GridSearchCV 遍历所有超参数组合,通过 交叉验证选择最佳参数,适用于 分类和回归任务
  • 如果参数空间 较大,可使用 RandomizedSearchCV 进行随机搜索
  • 如果数据 类别不均衡,应结合 StratifiedKFold 进行分层交叉验证
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彬彬侠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值