Sklearn-GridSearchCV网格搜索

本文介绍GridSearchCV在机器学习中的应用,包括参数设置、常见用法及实例演示。此外还对比了坐标下降法,适用于大数据集的参数调优。

GridSearchCV,它存在的意义就是自动调参,只要把参数输进去,就能给出最优化的结果和参数。但是这个方法适合于小数据集,一旦数据的量级上去了,很难得出结果。这个时候就是需要动脑筋了。数据量比较大的时候可以使用一个快速调优的方法——坐标下降它其实是一种贪心算法:拿当前对模型影响最大的参数调优,直到最优化再拿下一个影响最大的参数调优,如此下去,直到所有的参数调整完毕。这个方法的缺点就是可能会调到局部最优而不是全局最优,但是省时间省力,巨大的优势面前,还是试一试吧,后续可以再拿bagging再优化。

 

回到sklearn里面的GridSearchCVGridSearchCV用于系统地遍历多种参数组合,通过交叉验证确定最佳效果参数

GridSearchCVsklearn官方网址:http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html#sklearn.model_selection.GridSearchCV

classsklearn.model_selection.GridSearchCV(estimator,param_grid, scoring=None, fit_params=None, n_jobs=1, iid=True, refit=True,cv=None, verbose=0, pre_dispatch='2*n_jobs', error_score='raise',return_train_score=True)

 

  1. 常用参数解读

estimator:所使用的分类器,如estimator=RandomForestClassifier(min_samples_split=100,min_samples_leaf=20,max_depth=8,max_features='sqrt',random_state=10), 并且传入除需要确定最佳的参数之外的其他参数。每一个分类器都需要一个scoring参数,或者score方法。

param_grid:值为字典或者列表,即需要最优化的参数的取值,param_grid =param_test1,param_test1 = {'n_estimators':range(10,71,10)}。

scoring :准确度评价标准,默认None,这时需要使用score函数;或者如scoring='roc_auc',根据所选模型不同,评价准则不同。字符串(函数名),或是可调用对象,需要其函数签名形如:scorer(estimator, X, y);如果是None,则使用estimator的误差估计函数

cv :交叉验证参数,默认None,使用三折交叉验证。指定fold数量,默认为3,也可以是yield训练/测试数据的生成器

refit :默认为True,程序将会以交叉验证训练集得到的最佳参数,重新对所有可用的训练集与开发集进行,作为最终用于性能评估的最佳模型参数。即在搜索参数结束后,用最佳参数结果再次fit一遍全部数据集

iid:默认True,True时,默认为各个样本fold概率分布一致,误差估计为所有样本之和,而非各个fold的平均

verbose:日志冗长度,int:冗长度,0:不输出训练过程,1:偶尔输出,>1:对每个子模型都输出。

n_jobs: 并行数,int:个数,-1:跟CPU核数一致, 1:默认值。

pre_dispatch:指定总共分发的并行任务数。当n_jobs大于1时,数据将在每个运行点进行复制,这可能导致OOM,而设置pre_dispatch参数,则可以预先划分总共的job数量,使数据最多被复制pre_dispatch次

  1. 进行预测的常用方法和属性

grid.fit():运行网格搜索

grid_scores_:给出不同参数情况下的评价结果

best_params_描述了已取得最佳结果的参数的组合

best_score_:成员提供优化过程期间观察到的最好的评分

  1. 网格搜索实例

param_test1 ={'n_estimators':range(10,71,10)}
gsearch1= GridSearchCV(estimator =RandomForestClassifier(min_samples_split=100,
                                 min_samples_leaf=20,max_depth=       8,max_features='sqrt',random_state=10), 
                       param_grid =param_test1,scoring='roc_auc',cv=5)
gsearch1.fit(X,y)
gsearch1.grid_scores_, gsearch1.best_params_, gsearch1.best_score_

输出结果如下:

([mean: 0.80681, std:0.02236, params: {'n_estimators': 10},

  mean: 0.81600, std: 0.03275, params:{'n_estimators': 20},

  mean: 0.81818, std: 0.03136, params:{'n_estimators': 30},

  mean: 0.81838, std: 0.03118, params:{'n_estimators': 40},

  mean: 0.82034, std: 0.03001, params:{'n_estimators': 50},

  mean: 0.82113, std: 0.02966, params:{'n_estimators': 60},

  mean: 0.81992, std: 0.02836, params:{'n_estimators': 70}],

{'n_estimators': 60},

0.8211334476626017)

 

如果有transform,使用Pipeline简化系统搭建流程,将transform与分类器串联起来(Pipelineof transforms with a final estimator)

pipeline= Pipeline([("features", combined_features), ("svm", svm)])
param_grid= dict(features__pca__n_components=[1, 2, 3],
                  features__univ_select__k=[1,2],
                  svm__C=[0.1, 1, 10])
 
grid_search= GridSearchCV(pipeline, param_grid=param_grid, verbose=10)
grid_search.fit(X,y)
print(grid_search.best_estimator_)


<think>首先,用户的问题是关于在PyCharm中使用scikit-learn的GridSearchCV。内容包括使用方法、示例和注意事项。用户提供了几个引用,但明确指出这些仅供参考,并非用户真正的需求,我应该参考用户上一次的问题和我的回答。用户上一次的问题是:"我想了解在PyCharm中使用scikit-learn(sklearn)的GridSearchCV的相关内容",所以我会专注于这个主题。 系统级指令: - 所有行内数学表达式必须使用$...$格式,例如变量、方程等。 - 独立公式必须使用$$...$$格式并单独成段。 - LaTeX语法正确。 - 使用中文回答。 - 生成相关问题。 - 回答中引用的段落末尾自然地添加引用标识,如[^1]。 用户可见层指令: - 回答结构清晰,帮助用户逐步解决问题。 - 尽量保证回答真实可靠。 从引用中提取相关信息: - 引用[1]提到sklearn中使用决策树和交叉熵验证的示例,提到了cross_val_score。 - 引用[2]提到在成长股分析中使用pandas、numpy、scikit-learn。 - 引用[3]提到Keras库,但与GridSearchCV无关。 用户说这些引用仅供参考,所以我会在回答中自然地引用它们,如果需要的话,但主要焦点是GridSearchCV。 关键主题:GridSearchCV是scikit-learn中的一个工具,用于超参数调优。它通过交叉验证来搜索最佳参数组合。 回答结构: 1. **介绍GridSearchCV**:解释它是什么、为什么使用它。 2. **使用方法**:一步步说明如何在PyCharm中设置和使用GridSearchCV- 包括安装scikit-learn(如果需要)。 - 导入必要库。 - 定义参数网格。 - 创建一个GridSearchCV对象。 - 拟合模型。 - 获取最佳参数和模型评分。 3. **示例代码**:提供一个完整的示例,使用一个简单的数据集,比如Iris或Diabetes数据集。 - 代码块使用Markdown格式。 - 确保语法正确。 4. **注意事项**:讨论常见问题,如计算成本、参数选择、交叉验证设置等。 5. **生成相关问题**:在回答结束时,添加几个相关问题。 在回答中,自然地添加引用标识。例如,如果引用中提到相关内容,在末尾加[^1],但用户说引用仅供参考,所以我会适度使用。 确保行内数学表达式:由于主题是机器学习,可能会有损失函数等数学表达式。例如,提到均方误差时,用$MSE$。 例如: - 在解释评分时,可以用$R^2$ score。 - 独立公式:如果需要,比如定义一个公式,但这里可能不多。 步骤: 1. **导入库**:from sklearn.model_selection import GridSearchCV 2. **选择模型和参数网格** 3. **拟合和评估** 示例数据集:使用Iris分类或Diabetes回归,简单易懂。 PyCharm特定的部分:PyCharm是一个IDE,所以提到如何设置环境、运行脚本等,但核心是sklearn,所以重点在代码。 生成相关问题: - 例如: 1. 如何优化GridSearchCV的计算时间? 2. GridSearchCV与RandomizedSearchCV的区别? 3. 在处理大数据集时使用GridSearchCV的注意事项? 完整回答大纲: - 开头:简述GridSearchCV- 分段解释方法。 - 提供示例代码。 - 注意事项。 - 相关问题。 使用中文回答。 </think>### 在PyCharm中使用scikit-learn的GridSearchCV 在本回答中,我将详细介绍在PyCharm IDE中使用scikit-learn(sklearn)的GridSearchCV进行超参数调优的方法、完整示例、注意事项及相关问题。GridSearchCVsklearn中的一个核心工具,它通过交叉验证自动搜索最佳参数组合,适用于分类和回归模型优化[^1]。PyCharm作为Python IDE,便于代码编写、调试和运行,但核心逻辑与普通Python脚本一致。下面我将逐步引导您完成整个过程。 #### 1. GridSearchCV的基本概念 GridSearchCV用于系统地遍历一组预定义的超参数组合(即“网格”),并使用交叉验证评估每个组合的性能。目标是找到使模型性能最优的参数。例如,对于一个决策树模型,可能搜索`max_depth`和`min_samples_split`的最优值。常用指标包括准确率(准确率公式为$\text{accuracy} = \frac{\text{正确预测数}}{\text{总样本数}}$)或均方误差($MSE$)。其优势在于自动化参数搜索,但计算成本较高。 #### 2. 在PyCharm中使用GridSearchCV的步骤 在PyCharm中,首先确保环境配置正确: - **安装依赖**:通过PyCharm的Terminal或pip安装scikit-learn、numpy和pandas(引用[2]提到这些库是基础)。运行: ```bash pip install scikit-learn numpy pandas ``` - **创建项目**:新建Python文件,导入所需库。 - **核心步骤**: 1. 导入必要模块。 2. 加载数据集(例如sklearn内置数据集)。 3. 定义参数网格(需搜索的超参数范围)。 4. 创建GridSearchCV对象,指定模型、参数网格、交叉验证折数等。 5. 拟合模型(自动执行搜索和交叉验证)。 6. 获取最佳参数和模型评分。 #### 3. 完整示例代码 以下是使用Iris数据集(分类问题)的示例代码。代码可直接在PyCharm中运行。我们选择一个简单的分类模型(如决策树),并使用GridSearchCV调优参数。 ```python # -*- coding: utf-8 -*- from sklearn.model_selection import GridSearchCV from sklearn.tree import DecisionTreeClassifier from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split import pandas as pd # 加载数据集 iris = load_iris() X = iris.data # 特征数据 y = iris.target # 目标变量 # 划分训练集和测试集(80%训练,20%测试) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) # 定义决策树模型 dt_classifier = DecisionTreeClassifier(random_state=42) # 设置参数网格:搜索max_depth和min_samples_split的不同组合 param_grid = { 'max_depth': [3, 5, 7], # 树的最大深度 'min_samples_split': [2, 5, 10] # 节点分裂所需最小样本数 } # 创建GridSearchCV对象 # scoring="accuracy"表示使用准确率评估,cv=5表示5折交叉验证 grid_search = GridSearchCV(estimator=dt_classifier, param_grid=param_grid, scoring='accuracy', cv=5) # 拟合模型:在训练集上执行网格搜索 grid_search.fit(X_train, y_train) # 输出最佳参数和最佳评分 print("最佳参数:", grid_search.best_params_) print("最佳交叉验证准确率:", grid_search.best_score_) # 评估在测试集上的性能 best_model = grid_search.best_estimator_ test_accuracy = best_model.score(X_test, y_test) print("测试集准确率:", test_accuracy) # 查看所有参数组合的评分结果 results_df = pd.DataFrame(grid_search.cv_results_) print("参数搜索详细结果:") print(results_df[['params', 'mean_test_score']]) ``` **代码说明**: - **数据集**:使用sklearn的`load_iris()`加载Iris数据集,适合初学者。 - **参数网格**:`param_grid`定义了要搜索的超参数范围。例如,`max_depth`有三个候选值。 - **GridSearchCV对象**:`estimator`指定基础模型,`scoring`设置评估指标(这里用准确率),`cv=5`表示5折交叉验证。 - **输出结果**:`best_params_`给出最优参数,`best_score_`给出交叉验证平均分。最后使用最佳模型在测试集上验证。 - **运行在PyCharm**:新建Python文件,粘贴代码;点击运行按钮(或Ctrl+Shift+F10),结果将在Console输出。 #### 4. 注意事项 使用GridSearchCV时,需注意以下几点: - **计算成本高**:参数组合增多会指数级增加计算时间。例如,如果参数网格有10个值,cv=5,则需训练50个模型。建议从小网格开始测试(如上例)。 - **选择合适的评估指标**:分类问题常用`scoring="accuracy"`或`scoring="f1"`;回归问题可用`scoring="neg_mean_squared_error"`(引用[1]提到负均方误差)。确保指标与问题类型匹配。 - **随机种子设置**:使用`random_state`确保结果可复现(如`DecisionTreeClassifier(random_state=42)`)。 - **数据和预处理**:确保数据已标准化或归一化(如使用`StandardScaler`),避免参数搜索偏差。大数据集可考虑采样或增量学习。 - **交叉验证设置**:`cv`参数默认5折,但可调整。对于小数据集,增加折数(如cv=10)获得更稳健评估;大数据集可减少折数以加速计算。 - **并行化加速**:设置`n_jobs=-1`使用所有CPU核心加速搜索(添加到GridSearchCV参数中),但需PyCharm环境支持多线程。 - **模型选择**:GridSearchCV适用于任何sklearn模型,如SVM或随机森林(引用[1]涉及决策树)。引用[2]提到的pandas可用于结果分析,如示例中的`pd.DataFrame`。 #### 5. 相关问题 以下是基于本主题的延伸问题,供您进一步学习: 1. GridSearchCV与RandomizedSearchCV有何区别?哪个更适合大型参数空间? 2. 如何在高计算成本下优化GridSearchCV的性能(例如使用GPU或分布式计算)? 3. 在PyCharm中如何可视化GridSearchCV的结果(如使用matplotlib绘图显示参数与评分关系)? 4. GridSearchCV在处理不平衡数据集时有哪些常见问题和解决方案?
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值