(4) 评估模型的性能并调参:
更详细的可以查看萌弟大佬的知乎:https://zhuanlan.zhihu.com/p/140040705
import pandas as pd
import numpy as np
from sklearn import datasets
# Load the iris dataset; X (feature matrix) and y (labels) stay available as
# module-level names for the later tuning code.
iris = datasets.load_iris()
X, y = iris.data, iris.target
feature = iris.feature_names

# Build a DataFrame of the features with the labels appended as a 'target'
# column, then preview the first rows.
data = pd.DataFrame(X, columns=feature).assign(target=y)
data.head()
| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target |
---|---|---|---|---|---|
0 | 5.1 | 3.5 | 1.4 | 0.2 | 0 |
1 | 4.9 | 3.0 | 1.4 | 0.2 | 0 |
2 | 4.7 | 3.2 | 1.3 | 0.2 | 0 |
3 | 4.6 | 3.1 | 1.5 | 0.2 | 0 |
4 | 5.0 | 3.6 | 1.4 | 0.2 | 0 |
# 使用网格搜索进行超参数调优:
# 方式1:网格搜索GridSearchCV()
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import make_pipeline # 引入管道简化学习流程
from sklearn.preprocessing import StandardScaler
import time
# Time an exhaustive grid search over SVC hyperparameters.
start_time = time.time()

# Standardize the features and then fit an SVC; the 'svc__' prefix on the
# parameter keys below addresses the SVC step of this pipeline.
pipe_svc = make_pipeline(StandardScaler(), SVC(random_state=1))

param_range = [0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0]
# Two sub-grids: the linear kernel only varies C, while the RBF kernel
# varies both C and gamma.
param_grid = [
    {'svc__C': param_range, 'svc__kernel': ['linear']},
    {'svc__C': param_range, 'svc__gamma': param_range, 'svc__kernel': ['rbf']},
]

gs = GridSearchCV(
    estimator=pipe_svc,
    param_grid=param_grid,
    scoring='accuracy',
    cv=10,
    n_jobs=-1,
)
gs = gs.fit(X, y)
end_time = time.time()

# Report the wall-clock duration, the best cross-validated accuracy and the
# parameter combination that achieved it.
elapsed = end_time - start_time
print("网格搜索经历时间:%.3f S" % elapsed)
print(gs.best_score_)
print(gs.best_params_)
网格搜索经历时间:5.388 S
0.9800000000000001
{'svc__C': 1.0, 'svc__gamma': 0.1, 'svc__kernel': 'rbf'}
# 方式2:随机网格搜索RandomizedSearchCV()
from sklearn.model_selection import RandomizedSearchCV
from sklearn.svm import SVC
import time
# Time a randomized search over the same SVC search space used for the
# grid search.
start_time = time.time()

# Standardize the features and then fit an SVC; the 'svc__' prefix on the
# parameter keys below addresses the SVC step of this pipeline.
pipe_svc = make_pipeline(StandardScaler(), SVC(random_state=1))

param_range = [0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0]
# Two sub-spaces: the linear kernel only varies C, while the RBF kernel
# varies both C and gamma.
param_grid = [
    {'svc__C': param_range, 'svc__kernel': ['linear']},
    {'svc__C': param_range, 'svc__gamma': param_range, 'svc__kernel': ['rbf']},
]
# Alternative single-dict search space:
# param_grid = [{'svc__C':param_range,'svc__kernel':['linear','rbf'],'svc__gamma':param_range}]

# n_iter is not set, so the library default number of sampled candidates
# applies. random_state seeds the candidate sampling so that repeated runs
# report the same best score and parameters.
gs = RandomizedSearchCV(
    estimator=pipe_svc,
    param_distributions=param_grid,
    scoring='accuracy',
    cv=10,
    n_jobs=-1,
    random_state=1,
)
gs = gs.fit(X, y)
end_time = time.time()

# Report the wall-clock duration, the best cross-validated accuracy and the
# parameter combination that achieved it.
print("随机网格搜索经历时间:%.3f S" % float(end_time - start_time))
print(gs.best_score_)
print(gs.best_params_)
随机网格搜索经历时间:0.102 S
0.9733333333333334
{'svc__kernel': 'rbf', 'svc__gamma': 0.001, 'svc__C': 100.0}
# 没能用贝叶斯优化svc,这里就是用lgb来跑一下
import lightgbm as lgb
from bayes_opt import BayesianOptimization
from sklearn.model_selection import cross_val_score
# Placeholders for the training data that GBM_evaluate reads as globals;
# they are bound to the real arrays in the __main__ section before the
# Bayesian search is started.
train_X, train_y = None, None
def BayesianSearch(clf, params, num_iter=25, init_points=5):
    """Run Bayesian optimization of ``clf`` over the ``params`` search space.

    Args:
        clf: Objective function to maximize. It is evaluated on the
            hyperparameters named in ``params`` and must return a score.
        params: Mapping of hyperparameter name to a ``(lower, upper)``
            bounds tuple.
        num_iter: Number of optimization iterations performed after the
            initial points (defaults to the previously hard-coded 25).
        init_points: Number of initial exploration points (defaults to the
            previously hard-coded 5).

    Returns:
        The ``'params'`` entry of the optimizer's best result, i.e. the
        hyperparameter values that achieved the highest score. (The earlier
        version returned the unchanged input bounds, which told the caller
        nothing about the outcome of the search.)
    """
    # Build the optimizer from the objective function and the search bounds.
    bayes = BayesianOptimization(clf, params)
    # Run the search.
    bayes.maximize(init_points=init_points, n_iter=num_iter)
    # bayes.max holds the best score ('target') and its 'params'.
    best = bayes.max
    print("Final result:", best)
    return best["params"]
def GBM_evaluate(min_child_samples, min_child_weight, colsample_bytree, max_depth, subsample, reg_alpha, reg_lambda):
    """Objective function for the Bayesian optimizer: score one LightGBM setup.

    Every argument is a candidate hyperparameter value proposed by the
    optimizer. ``min_child_samples``, ``min_child_weight`` and ``max_depth``
    are truncated to ``int``; the remaining ones are passed as ``float``.

    Reads the module-level ``train_X`` and ``train_y`` as the training data.

    Returns:
        The mean 5-fold cross-validated ``neg_mean_squared_error`` of an
        ``LGBMRegressor`` built with these hyperparameters. The score is the
        negated MSE because the optimizer maximizes its objective.
    """
    param = {
        # Fixed hyperparameters of the model.
        'objective': 'regression',
        'n_estimators': 275,
        'metric': 'rmse',
        'random_state': 2018,
        # Hyperparameters proposed by the optimizer. Each one is stored as a
        # plain scalar: the earlier version left trailing commas on several
        # of these assignments, which turned the values into 1-tuples.
        'min_child_weight': int(min_child_weight),
        'colsample_bytree': float(colsample_bytree),
        'max_depth': int(max_depth),
        'subsample': float(subsample),
        'reg_lambda': float(reg_lambda),
        'reg_alpha': float(reg_alpha),
        'min_child_samples': int(min_child_samples),
    }
    # 5-fold cross-validation scored with the negated MSE, so that a larger
    # value means a better model.
    val = cross_val_score(
        lgb.LGBMRegressor(**param),
        train_X,
        train_y,
        scoring='neg_mean_squared_error',
        cv=5,
    ).mean()
    return val
if __name__ == '__main__':
    # Bind the module-level training data read by GBM_evaluate to the
    # X / y arrays defined earlier in this file.
    train_X = X
    train_y = y

    # Search space: (lower, upper) bounds for each tuned hyperparameter.
    adj_params = dict(
        min_child_weight=(3, 20),
        colsample_bytree=(0.4, 1),
        max_depth=(5, 15),
        subsample=(0.5, 1),
        reg_lambda=(0.1, 1),
        reg_alpha=(0.1, 1),
        min_child_samples=(10, 30),
    )

    # Run the Bayesian optimization with GBM_evaluate as the objective.
    BayesianSearch(GBM_evaluate, adj_params)
| iter | target | colsam... | max_depth | min_ch... | min_ch... | reg_alpha | reg_la... | subsample |
-------------------------------------------------------------------------------------------------------------
| [0m 1 [0m | [0m-0.04982 [0m | [0m 0.8372 [0m | [0m 13.37 [0m | [0m 15.69 [0m | [0m 18.24 [0m | [0m 0.7529 [0m | [0m 0.1891 [0m | [0m 0.5799 [0m |
| [95m 2 [0m | [95m-0.0494 [0m | [95m 0.6598 [0m | [95m 14.47 [0m | [95m 16.88 [0m | [95m 18.61 [0m | [95m 0.6301 [0m | [95m 0.2553 [0m | [95m 0.7997 [0m |
| [0m 3 [0m | [0m-0.05979 [0m | [0m 0.6296 [0m | [0m 13.05 [0m | [0m 11.51 [0m | [0m 17.39 [0m | [0m 0.1482 [0m | [0m 0.614 [0m | [0m 0.9209 [0m |
| [0m 4 [0m | [0m-0.06427 [0m | [0m 0.4474 [0m | [0m 7.144 [0m | [0m 10.28 [0m | [0m 4.937 [0m | [0m 0.8855 [0m | [0m 0.7714 [0m | [0m 0.9347 [0m |
| [0m 5 [0m | [0m-0.05847 [0m | [0m 0.5935 [0m | [0m 5.444 [0m | [0m 19.74 [0m | [0m 9.531 [0m | [0m 0.8754 [0m | [0m 0.8536 [0m | [0m 0.6944 [0m |
| [0m 6 [0m | [0m-0.05176 [0m | [0m 1.0 [0m | [0m 10.81 [0m | [0m 19.11 [0m | [0m 20.0 [0m | [0m 1.0 [0m | [0m 0.1 [0m | [0m 0.5 [0m |
| [0m 7 [0m | [0m-0.05183 [0m | [0m 1.0 [0m | [0m 14.28 [0m | [0m 19.72 [0m | [0m 14.41 [0m | [0m 1.0 [0m | [0m 0.1 [0m | [0m 0.5 [0m |
| [0m 8 [0m | [0m-0.06693 [0m | [0m 0.519 [0m | [0m 15.0 [0m | [0m 26.45 [0m | [0m 19.97 [0m | [0m 0.1 [0m | [0m 0.1 [0m | [0m 0.5 [0m |
| [0m 9 [0m | [0m-0.05576 [0m | [0m 0.4 [0m | [0m 5.0 [0m | [0m 15.99 [0m | [0m 20.0 [0m | [0m 0.2854 [0m | [0m 1.0 [0m | [0m 0.5 [0m |
| [0m 10 [0m | [0m-0.07412 [0m | [0m 1.0 [0m | [0m 15.0 [0m | [0m 30.0 [0m | [0m 3.0 [0m | [0m 1.0 [0m | [0m 0.1 [0m | [0m 0.5 [0m |
| [0m 11 [0m | [0m-0.06855 [0m | [0m 0.4 [0m | [0m 5.0 [0m | [0m 30.0 [0m | [0m 16.61 [0m | [0m 0.1 [0m | [0m 1.0 [0m | [0m 1.0 [0m |
| [0m 12 [0m | [0m-0.05707 [0m | [0m 0.4 [0m | [0m 10.52 [0m | [0m 17.63 [0m | [0m 15.55 [0m | [0m 0.1 [0m | [0m 1.0 [0m | [0m 1.0 [0m |
| [0m 13 [0m | [0m-0.05279 [0m | [0m 1.0 [0m | [0m 15.0 [0m | [0m 16.63 [0m | [0m 8.128 [0m | [0m 1.0 [0m | [0m 0.1 [0m | [0m 0.5 [0m |
| [0m 14 [0m | [0m-0.07298 [0m | [0m 1.0 [0m | [0m 15.0 [0m | [0m 10.0 [0m | [0m 3.0 [0m | [0m 0.1 [0m | [0m 0.1 [0m | [0m 0.5 [0m |
| [0m 15 [0m | [0m-0.06029 [0m | [0m 0.4 [0m | [0m 15.0 [0m | [0m 15.65 [0m | [0m 13.26 [0m | [0m 1.0 [0m | [0m 0.1 [0m | [0m 0.5 [0m |
| [95m 16 [0m | [95m-0.04908 [0m | [95m 1.0 [0m | [95m 15.0 [0m | [95m 21.39 [0m | [95m 8.975 [0m | [95m 1.0 [0m | [95m 1.0 [0m | [95m 1.0 [0m |
| [0m 17 [0m | [0m-0.05176 [0m | [0m 1.0 [0m | [0m 12.02 [0m | [0m 20.1 [0m | [0m 5.116 [0m | [0m 1.0 [0m | [0m 0.1 [0m | [0m 0.5 [0m |
| [95m 18 [0m | [95m-0.0463 [0m | [95m 1.0 [0m | [95m 5.0 [0m | [95m 22.9 [0m | [95m 3.0 [0m | [95m 0.1 [0m | [95m 0.1 [0m | [95m 0.5 [0m |
| [0m 19 [0m | [0m-0.06587 [0m | [0m 0.4 [0m | [0m 5.0 [0m | [0m 27.41 [0m | [0m 3.0 [0m | [0m 0.1 [0m | [0m 1.0 [0m | [0m 1.0 [0m |
| [0m 20 [0m | [0m-0.05183 [0m | [0m 1.0 [0m | [0m 5.0 [0m | [0m 19.19 [0m | [0m 3.0 [0m | [0m 1.0 [0m | [0m 0.1 [0m | [0m 0.5 [0m |
| [0m 21 [0m | [0m-0.05243 [0m | [0m 0.4956 [0m | [0m 7.591 [0m | [0m 21.39 [0m | [0m 3.431 [0m | [0m 0.2133 [0m | [0m 0.6901 [0m | [0m 0.8303 [0m |
| [0m 22 [0m | [0m-0.0463 [0m | [0m 1.0 [0m | [0m 11.54 [0m | [0m 22.35 [0m | [0m 10.16 [0m | [0m 0.1 [0m | [0m 0.1 [0m | [0m 0.5 [0m |
| [0m 23 [0m | [0m-0.04786 [0m | [0m 1.0 [0m | [0m 14.01 [0m | [0m 24.81 [0m | [0m 11.77 [0m | [0m 1.0 [0m | [0m 0.1 [0m | [0m 0.5 [0m |
| [0m 24 [0m | [0m-0.05176 [0m | [0m 1.0 [0m | [0m 5.0 [0m | [0m 10.0 [0m | [0m 20.0 [0m | [0m 1.0 [0m | [0m 0.1 [0m | [0m 0.5 [0m |
| [0m 25 [0m | [0m-0.05574 [0m | [0m 0.4 [0m | [0m 10.79 [0m | [0m 23.68 [0m | [0m 13.0 [0m | [0m 1.0 [0m | [0m 1.0 [0m | [0m 0.5 [0m |
| [0m 26 [0m | [0m-0.04786 [0m | [0m 1.0 [0m | [0m 12.3 [0m | [0m 24.35 [0m | [0m 8.227 [0m | [0m 1.0 [0m | [0m 0.1 [0m | [0m 0.5 [0m |
| [0m 27 [0m | [0m-0.05282 [0m | [0m 1.0 [0m | [0m 5.0 [0m | [0m 10.0 [0m | [0m 14.19 [0m | [0m 1.0 [0m | [0m 0.1 [0m | [0m 0.5 [0m |
| [0m 28 [0m | [0m-0.04726 [0m | [0m 1.0 [0m | [0m 15.0 [0m | [0m 24.04 [0m | [0m 9.446 [0m | [0m 0.1 [0m | [0m 0.1 [0m | [0m 0.5 [0m |
| [0m 29 [0m | [0m-0.05176 [0m | [0m 1.0 [0m | [0m 9.834 [0m | [0m 14.18 [0m | [0m 20.0 [0m | [0m 1.0 [0m | [0m 0.1 [0m | [0m 0.5 [0m |
| [0m 30 [0m | [0m-0.0522 [0m | [0m 0.4 [0m | [0m 14.17 [0m | [0m 22.28 [0m | [0m 11.2 [0m | [0m 0.1 [0m | [0m 0.1 [0m | [0m 1.0 [0m |
=============================================================================================================
Final result: {'target': -0.046298141017010144, 'params': {'colsample_bytree': 1.0, 'max_depth': 5.0, 'min_child_samples': 22.89726011887252, 'min_child_weight': 3.0, 'reg_alpha': 0.1, 'reg_lambda': 0.1, 'subsample': 0.5}}
对于二分类问题可以绘制混淆矩阵和ROC曲线,这里就不粘贴文档了,可以试着手动实现一下
本月的学习展示告一段落,下个月将会学习集成学习。对于本次的小任务——“结合sklearn的fetch_lfw_people数据集,进行一次实战”,我会在下次开课前另开一篇文章来完成!