交叉验证
网格搜索
参考:
https://www.jianshu.com/p/20456b512fa7
import pyspark.ml.tuning as tune
# 超参调优:grid search和train-validation splitting
# 网格搜索
import pyspark.ml.tuning as tune
# Base estimator whose hyperparameters are tuned; the label column is
# INFANT_ALIVE_AT_REPORT.
logistic = cl.LogisticRegression(labelCol='INFANT_ALIVE_AT_REPORT')

# Candidate grid: three maxIter values crossed with three regParam values.
grid = (
    tune.ParamGridBuilder()
    .addGrid(logistic.maxIter, [5, 10, 50])
    .addGrid(logistic.regParam, [0.01, 0.05, 0.3])
    .build()
)
# Define how candidate models are compared: a binary-classification
# evaluator that reads scores from the 'probability' column and the true
# label from INFANT_ALIVE_AT_REPORT.
evaluator = ev.BinaryClassificationEvaluator(
    labelCol='INFANT_ALIVE_AT_REPORT',
    rawPredictionCol='probability',
)
# Evaluate every parameter map in `grid` with 3-fold cross-validation.
# Models are ranked with `evaluator` (its default metric is areaUnderROC
# per the PySpark docs).
cv = tune.CrossValidator(
    numFolds=3,
    estimator=logistic,
    estimatorParamMaps=grid,
    evaluator=evaluator,
)
# Re-split the data 70/30 into train and test sets (fixed seed).
birth_train, birth_test = births.randomSplit(weights=[0.7, 0.3], seed=123)

# The raw data cannot be fed to the model directly, so build a
# feature-construction pipeline first.
pipeline = Pipeline(stages=[encoder, featuresCreator])

# Fit the feature pipeline on the training split only, then apply the
# fitted transformer to the test split.
data_transformer = pipeline.fit(birth_train)
data_test = data_transformer.transform(birth_test)
# Search for the best parameter combination on the transformed training
# split; cvModel is the fitted result holding the best model found.
cvModel = cv.fit(data_transformer.transform(birth_train))
results = cvModel.transform(data_test)

# Check performance on the test split: area under ROC, then area under PR.
for metric_name in ('areaUnderROC', 'areaUnderPR'):
    print(evaluator.evaluate(results, {evaluator.metricName: metric_name}))
# Output (areaUnderROC, then areaUnderPR):
# 0.735848884034915
# 0.6959036715961695
# Inspect the best parameter combination: pair each parameter map (as a list
# of single-entry {name: value} dicts) with its average cross-validation
# metric, then pick the pair with the highest metric.
# NOTE: this rebinds `results`, which previously held the scored test
# DataFrame.
results = [
    (
        [{param.name: value} for param, value in params.items()],
        metric,
    )
    for params, metric in zip(
        cvModel.getEstimatorParamMaps(),
        cvModel.avgMetrics,
    )
]
# Bare expression: its value is displayed when run in a notebook/REPL.
# max() returns the first highest-scoring entry, the same element that
# sorted(..., reverse=True)[0] yields, without sorting the whole list.
max(results, key=lambda el: el[1])
# Alternative: collect each parameter map as one {name: value} dict, paired
# with its average cross-validation metric, and list every combination with
# the highest metric first.
param_maps = cvModel.getEstimatorParamMaps()
eval_metrics = cvModel.avgMetrics
param_res = [
    ({param.name: value for param, value in params.items()}, metric)
    for params, metric in zip(param_maps, eval_metrics)
]
# Bare expression: its value is displayed when run in a notebook/REPL.
sorted(param_res, key=lambda x: x[1], reverse=True)
# Output (parameter combinations ranked by average cross-validation metric):
# [({'maxIter': 50, 'regParam': 0.01}, 0.7406291618177623),
#  ({'maxIter': 10, 'regParam': 0.01}, 0.735580969909943),
#  ({'maxIter': 50, 'regParam': 0.05}, 0.7355100622938429),
#  ({'maxIter': 10, 'regParam': 0.05}, 0.7351586303619441),
#  ({'maxIter': 10, 'regParam': 0.3}, 0.7248698034708339),
#  ({'maxIter': 50, 'regParam': 0.3}, 0.7214679272915997),
#  ({'maxIter': 5, 'regParam': 0.3}, 0.7180255703028883),
#  ({'maxIter': 5, 'regParam': 0.01}, 0.7179304617840288),
#  ({'maxIter': 5, 'regParam': 0.05}, 0.7173397593133481)]