sparkml_实战全流程_LogisticRegression(二)

交叉验证
网格搜索

参考:
https://www.jianshu.com/p/20456b512fa7


import pyspark.ml.tuning as tune
# 超参调优:grid search和train-validation splitting # 网格搜索
import pyspark.ml.tuning as tune
​
logistic = cl.LogisticRegression(labelCol='INFANT_ALIVE_AT_REPORT')
grid = tune.ParamGridBuilder()\
    .addGrid(logistic.maxIter, [5,10,50])\
    .addGrid(logistic.regParam, [0.01,0.05,0.3])\
    .build()# 找出模型之间比较的方法
evaluator = ev.BinaryClassificationEvaluator(
    rawPredictionCol='probability',
    labelCol='INFANT_ALIVE_AT_REPORT'
)# 使用K-Fold交叉验证评估各种参数的模型
cv = tune.CrossValidator(
    estimator=logistic,
    estimatorParamMaps=grid,
    evaluator=evaluator,
    numFolds=3
)# 我们不能直接使用数据,所以我们
# 创建一个构建特征的pipeline
pipeline = Pipeline(stages=[encoder, featuresCreator])
birth_train, birth_test = births.randomSplit([0.7,0.3],seed=123) # 重新打开数据进行处理
data_transformer = pipeline.fit(birth_train)
data_test = data_transformer.transform(birth_test)
​
​
# cvModel 返回估计的最佳模型  
# 寻找模型最佳参数组合
​
cvModel = cv.fit(data_transformer.transform(birth_train))
results = cvModel.transform(data_test)# 查看效果
print(evaluator.evaluate(results, {evaluator.metricName:'areaUnderROC'}))
print(evaluator.evaluate(results, {evaluator.metricName:'areaUnderPR'}))0.735848884034915
0.6959036715961695
# 使用下面的代码可以查看模型最佳参数:
# 查看最佳模型参数
results = [
    (
        [
            {key.name: paramValue} 
            for key, paramValue 
            in zip(
                params.keys(), 
                params.values())
        ], metric
    ) 
    for params, metric 
    in zip(
        cvModel.getEstimatorParamMaps(), 
        cvModel.avgMetrics
    )
]sorted(results, 
       key=lambda el: el[1], 
       reverse=True)[0]# 或者
param_maps = cvModel.getEstimatorParamMaps()
eval_metrics = cvModel.avgMetrics
​
param_res = []for params, metric in zip(param_maps, eval_metrics):
    param_metric = {}
    for key, param_val in zip(params.keys(), params.values()):
        param_metric[key.name]=param_val
    param_res.append((param_metric, metric))sorted(param_res, key=lambda x:x[1], reverse=True)[({'maxIter': 50, 'regParam': 0.01}, 0.7406291618177623),
 ({'maxIter': 10, 'regParam': 0.01}, 0.735580969909943),
 ({'maxIter': 50, 'regParam': 0.05}, 0.7355100622938429),
 ({'maxIter': 10, 'regParam': 0.05}, 0.7351586303619441),
 ({'maxIter': 10, 'regParam': 0.3}, 0.7248698034708339),
 ({'maxIter': 50, 'regParam': 0.3}, 0.7214679272915997),
 ({'maxIter': 5, 'regParam': 0.3}, 0.7180255703028883),
 ({'maxIter': 5, 'regParam': 0.01}, 0.7179304617840288),
 ({'maxIter': 5, 'regParam': 0.05}, 0.7173397593133481)]
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值