1. GridSearchCV 简介GridSearchCV:原理、应用与实例

GridSearchCV 是 scikit-learn 库中用于超参数调优的一种方法。它通过对预定义的参数网格进行穷举搜索,并利用交叉验证来评估每组参数组合的表现,从而帮助我们找到模型的最佳超参数配置。


2. 工作原理

  • 参数网格定义: 用户需要预先设定一个包含多个参数组合的字典,每个参数对应多个可能的取值。
  • 穷举搜索: GridSearchCV 会对每种可能的参数组合进行训练。
  • 交叉验证: 对每个参数组合,GridSearchCV 会利用交叉验证来评估模型性能,确保结果的稳定性和泛化能力。
  • 最佳参数选择: 最终,它会选择在交叉验证中表现最好的参数组合,并提供对应的模型。

3. 示例代码

以下是一个使用 GridSearchCV 调整支持向量机 (SVM) 模型超参数的示例:

from sklearn import datasets
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.svm import SVC

# 加载数据集
iris = datasets.load_iris()
X, y = iris.data, iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 定义要搜索的参数网格
param_grid = {
    'C': [0.1, 1, 10],
    'kernel': ['linear', 'rbf'],
    'gamma': ['scale', 'auto']
}

# 创建 SVC 模型
svc = SVC()

# 实例化 GridSearchCV
grid_search = GridSearchCV(estimator=svc, param_grid=param_grid, cv=5, scoring='accuracy')

# 在训练集上执行搜索
grid_search.fit(X_train, y_train)

# 输出最佳参数和对应得分
print("最佳参数:", grid_search.best_params_)
print("最佳交叉验证准确率:", grid_search.best_score_)
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
  • 30.

这个示例展示了如何通过 GridSearchCV 对 SVM 模型进行超参数调优,并利用 5 折交叉验证来评估每个参数组合的效果。


4. 更多示例

示例 2:调优决策树模型
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

# 加载数据
iris = load_iris()
X, y = iris.data, iris.target

# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 定义参数网格
param_grid = {
    'max_depth': [3, 5, 7, None],
    'min_samples_split': [2, 5, 10]
}

# 创建决策树分类器
dtree = DecisionTreeClassifier(random_state=42)

# 实例化 GridSearchCV
grid_search = GridSearchCV(estimator=dtree, param_grid=param_grid, cv=5, scoring='accuracy')

# 搜索最佳参数
grid_search.fit(X_train, y_train)

# 打印最佳参数及得分
print("最佳参数:", grid_search.best_params_)
print("最佳交叉验证准确率:", grid_search.best_score_)
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
示例 3:调优随机森林模型
from sklearn.ensemble import RandomForestClassifier

# 定义参数网格
param_grid = {
    'n_estimators': [50, 100, 200],
    'max_depth': [None, 10, 20],
    'min_samples_split': [2, 5]
}

rf = RandomForestClassifier(random_state=42)
grid_search = GridSearchCV(estimator=rf, param_grid=param_grid, cv=5, scoring='accuracy')
grid_search.fit(X_train, y_train)

print("最佳参数:", grid_search.best_params_)
print("最佳交叉验证准确率:", grid_search.best_score_)
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.

5. 总结

GridSearchCV 是一种强大的工具,它可以自动化超参数搜索过程,通过交叉验证来评估每个参数组合的性能,从而帮助我们提高模型的表现。它适用于各种机器学习模型,但在参数组合较多时可能计算开销较大,此时可以考虑使用 RandomizedSearchCV 来降低计算负担。