《scikit-learn》通过GridSearchCV来进行超参数优化

微调的一种方法是手工调制超参数,直到找到一个好的超参数组合,这么做的话会非常冗长,你也可能没有时间探索多种组合,所以可以使用Scikit-Learn的GridSearchCV来做这项搜索工作。

GridSearchCV的名字其实可以拆分为两部分,GridSearch和CV,即网格搜索和交叉验证。这两个名字都非常好理解。网格搜索,搜索的是参数,即在指定的参数范围内,按步长依次调整参数,利用调整的参数训练学习器,从所有的参数中找到在验证集上精度最高的参数,这其实是一个训练和比较的过程。
GridSearchCV可以保证在指定的参数范围内找到精度最高的参数,但是这也是网格搜索的缺陷所在,他要求遍历所有可能参数的组合,在面对大数据集和多参数的情况下,非常耗时。
以分类树作为例子来说明其用法。

from sklearn.model_selection import GridSearchCV
from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split

wine = load_wine()  # 178个数据,13个属性,三个分类种类。
# 把红酒数据进行切分,切分成训练集和测试集合,切分比例一般是7:3
data_train, data_test, target_train, target_test = train_test_split(wine.data, wine.target, test_size=0.3)

# 建立一些超参数的网格,会自动帮我们取选择最好的搭配
param_grid = [
    {'criterion': ['entropy'], 'max_features': [2, 4, 6, 8], 'max_depth': [5, 6, 7, 8],
     'min_impurity_decrease': [0.01, 0.02]},  # 一共有 1x4x4x2 个超参数组合
    {'criterion': ['gini'], 'max_features': [3, 5, 7, 8], 'max_depth': [4, 5, 6, 7],
     'min_impurity_decrease': [0.01, 0.02]},  # 一共有 1x4x4x2 个超参数组合
]  # 一共有 1x4x4x2 + 1x4x4x2 个超参数组合

dtc = tree.DecisionTreeClassifier(criterion='entropy', class_weight='balanced')  # 定义一个决策树的实例,决策树节点的分裂选择的是 根据信息熵划分
grid_search = GridSearchCV(dtc, param_grid, cv=5,
                           scoring='accuracy')  # GridSearchCV的名字其实可以拆分为两部分,GridSearch和CV,即网格搜索和交叉验证。

# 使用多种参数组合的网格来进行训练。
clf = grid_search.fit(data_train, target_train)  # 这一步也是最耗费时间的。
score = grid_search.score(data_test, target_test)
print(score)  # 得到最后的评价分数

print(grid_search.best_params_)  # 找到最优的超参数,这里帮我自动选择好了最优的参数列表
print(grid_search.best_estimator_)  # 找到最优化的模型,甚至把初始化参数都打印出来了

# 找到了最好的那个模型
estimator = grid_search.best_estimator_
estimator.fit(data_train, target_train)
score = estimator.score(data_test, target_test)
print(score)
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值