sklearn.model_selection.GridSearchCV
GridSearchCV
是 sklearn.model_selection
提供的 超参数优化工具,用于 遍历所有可能的超参数组合,通过 交叉验证 选择 最佳超参数。
1. GridSearchCV
作用
- 自动搜索最佳超参数组合,提高模型性能。
- 使用交叉验证(默认
cv=5
) 评估不同超参数的效果。 - 适用于分类和回归任务,支持 不同评分指标(
accuracy
、f1
、roc_auc
、r2
等)。
2. GridSearchCV
代码示例
(1) 超参数优化 SVM
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
from sklearn.datasets import load_iris
# 加载数据
iris = load_iris()
X, y = iris.data, iris.target
# 设定超参数搜索空间
param_grid = {
"C": [0.1, 1, 10], # 正则化参数
"kernel": ["linear", "rbf"] # 选择不同核函数
}
# 初始化 SVM 并进行网格搜索
grid_search = GridSearchCV(SVC(), param_grid, cv=5)
grid_search.fit(X, y)
print("最佳参数:", grid_search.best_params_)
print("最佳得分:", grid_search.best_score_)
输出
最佳参数: {'C': 1, 'kernel': 'linear'}
最佳得分: 0.98
解释
GridSearchCV
遍历所有可能的C
和kernel
组合,选择最优超参数。
(2) 使用 StratifiedKFold
进行分层交叉验证
from sklearn.model_selection import StratifiedKFold
cv = StratifiedKFold(n_splits=5)
grid_search = GridSearchCV(SVC(), param_grid, cv=cv)
grid_search.fit(X, y)
print("最佳参数:", grid_search.best_params_)
解释
- 适用于类别不均衡数据,保证交叉验证每折类别比例一致。
(3) 选择不同评分指标
grid_search = GridSearchCV(SVC(), param_grid, cv=5, scoring="accuracy")
grid_search.fit(X, y)
print("最佳参数:", grid_search.best_params_)
可选评分指标
任务 | 评分指标 (scoring ) | 说明 |
---|---|---|
分类 | "accuracy" | 准确率 |
分类 | "f1" | F1-score |
分类 | "roc_auc" | ROC AUC |
回归 | "r2" | R²(决定系数) |
回归 | "neg_mean_absolute_error" | 负 MAE |
回归 | "neg_mean_squared_error" | 负 MSE |
示例:
grid_search = GridSearchCV(SVC(), param_grid, cv=5, scoring="f1_macro")
grid_search.fit(X, y)
print("最佳参数:", grid_search.best_params_)
(4) 训练多个模型
from sklearn.ensemble import RandomForestClassifier
param_grid = {
"n_estimators": [10, 50, 100],
"max_depth": [None, 10, 20]
}
grid_search = GridSearchCV(RandomForestClassifier(), param_grid, cv=5)
grid_search.fit(X, y)
print("最佳参数:", grid_search.best_params_)
解释
GridSearchCV
可用于多种模型,如RandomForestClassifier
。
3. GridSearchCV
的参数
GridSearchCV(estimator, param_grid, scoring=None, cv=None, n_jobs=None, verbose=0)
参数 | 说明 |
---|---|
estimator | 评估器(模型),如 SVC() |
param_grid | 需要搜索的超参数 |
scoring | 评分指标(如 "accuracy" 、"f1" 、"roc_auc" 、"r2" ) |
cv | 交叉验证策略(默认 5 ,可传 KFold() 、StratifiedKFold() ) |
n_jobs | 并行计算(-1 表示使用所有 CPU 核心) |
verbose | 是否打印搜索过程(0 =不输出,1 =简单输出,2 =详细输出) |
4. 适用场景
- 超参数优化,提高模型性能。
- 分类/回归任务的模型调优。
- 结合
StratifiedKFold
处理类别不均衡数据。
5. GridSearchCV
vs. RandomizedSearchCV
vs. train_test_split
方法 | 适用情况 | 作用 |
---|---|---|
GridSearchCV | 参数范围较小,计算量可控 | 遍历所有参数组合 |
RandomizedSearchCV | 参数范围较大 | 随机选择部分参数搜索 |
train_test_split | 简单模型训练 | 训练集/测试集划分 |
示例:
from sklearn.model_selection import RandomizedSearchCV
from scipy.stats import uniform
# 设定参数分布
param_dist = {'C': uniform(0.1, 10), 'kernel': ['linear', 'rbf']}
# 进行随机搜索
random_search = RandomizedSearchCV(SVC(), param_dist, n_iter=5, cv=5, random_state=42)
random_search.fit(X, y)
print("最佳参数:", random_search.best_params_)
解释
RandomizedSearchCV
随机选取参数组合,适用于 大参数空间。
6. 结论
GridSearchCV
遍历所有超参数组合,通过 交叉验证选择最佳参数,适用于 分类和回归任务。- 如果参数空间 较大,可使用
RandomizedSearchCV
进行随机搜索。 - 如果数据 类别不均衡,应结合
StratifiedKFold
进行分层交叉验证。