Spark ML Pipeline 交叉验证之决策树分类
1.1 模型训练
1.1.1 输入参数
{
"modelName": "决策树分类_运动状态预测",
"numFolds": "5",
"labelColumn": "activityId",
"maxDepths": [
5,
10,
20
],
"maxBins": [
32,
200,
300
]
}
1.1.2 训练代码
import com.cetc.common.conf.MachineLearnModel
import com.cetc.miner.compute.utils.{ModelUtils, Utils}
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.evaluation.{MulticlassClassificationEvaluator, RegressionEvaluator}
import org.apache.spark.ml.feature.{StandardScaler, VectorAssembler}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.{DataFrame, SparkSession}
import scala.collection.JavaConverters._
class DTCBestTrain {
  val logger: org.apache.log4j.Logger = org.apache.log4j.Logger.getLogger(classOf[DTCBestTrain])

  /**
   * Trains a decision-tree classifier and selects `maxDepth` / `maxBins` by k-fold
   * cross-validation, using accuracy as the selection metric.
   *
   * Every column of `df` other than the label column is assembled into a single `features`
   * vector column, so those columns must be of a type VectorAssembler accepts (numeric,
   * boolean or vector — not string).
   *
   * @param df           training data; must contain the configured label column
   * @param id           not used by this method
   * @param name         not used by this method
   * @param conf         configuration parsed by `Utils.conf2Class`; this method reads
   *                     labelColumn, maxDepths, maxBins, numFolds and modelName from it
   * @param sparkSession not used by this method
   * @return a single-element Java list holding the descriptor built by `ModelUtils.saveModel`
   *         for the best tree; the same descriptor is also passed to `ModelUtils.model2mysql`
   */
  def execute(df: DataFrame, id: String, name: String, conf: String, sparkSession: SparkSession): java.util.List[MachineLearnModel] = {
    // Cross-validation scans df repeatedly (per fold and per grid point), so cache it while
    // training. Only cache/unpersist when the caller has not persisted it already, so a cache
    // the caller relies on is never evicted here.
    val cachedHere = df.storageLevel == org.apache.spark.storage.StorageLevel.NONE
    if (cachedHere) df.cache()
    try {
      logger.info("训练集个数=========" + df.count())
      val params = Utils.conf2Class(conf)
      val labelColumn = params.getLabelColumn

      // Merge all non-label columns into one vector column for the learner.
      val assembler = new VectorAssembler()
        .setInputCols(df.drop(labelColumn).columns)
        .setOutputCol("features")

      // No StandardScaler stage: the classifier reads the assembler's "features" column, so a
      // scaler's output would go unused (and threshold-based tree splits do not need scaled
      // inputs). Omitting it saves a statistics pass on every pipeline fit during
      // cross-validation without changing the fitted tree.
      val dtc = new DecisionTreeClassifier()
        .setFeaturesCol(assembler.getOutputCol)
        .setLabelCol(labelColumn)

      val pipeline = new Pipeline().setStages(Array(assembler, dtc))

      // Accuracy on the held-out fold is the criterion used to pick the best parameter combination.
      val evaluator = new MulticlassClassificationEvaluator()
        .setLabelCol(labelColumn)
        .setPredictionCol("prediction")
        .setMetricName("accuracy")

      // Candidate grid: every combination of the configured maxDepths x maxBins values.
      val paramGrid = new ParamGridBuilder()
        .addGrid(dtc.getParam("maxDepth"), params.getMaxDepths.asScala)
        .addGrid(dtc.getParam("maxBins"), params.getMaxBins.asScala)
        .build()

      // df is split into numFolds folds; for each fold, every grid point is trained on the other
      // numFolds - 1 folds and scored on that fold. The point with the best average metric wins.
      val crossValidator = new CrossValidator()
        .setEstimator(pipeline)
        .setEstimatorParamMaps(paramGrid)
        .setNumFolds(params.getNumFolds)
        .setEvaluator(evaluator)

      val cvModel = crossValidator.fit(df)

      // bestModel is a PipelineModel because the estimator is a Pipeline. Locate the tree by type
      // rather than by a hard-coded stage index so the pipeline layout can change safely.
      val bestPipeline = cvModel.bestModel.asInstanceOf[PipelineModel]
      val dtcModel = bestPipeline.stages
        .collectFirst { case m: DecisionTreeClassificationModel => m }
        .getOrElse(throw new IllegalStateException("best pipeline contains no DecisionTreeClassificationModel stage"))
      logger.info(s"模型类型========${dtcModel.getClass}")

      // NOTE(review): the meaning of the literal arguments 3, 0 and 0.0 is defined by
      // ModelUtils.saveModel, which is not visible here — confirm before changing them.
      val modelObject: MachineLearnModel = ModelUtils.saveModel(dtcModel, params.getModelName, 3, conf, 0, 0.0)
      // Persist the model descriptor to the database.
      ModelUtils.model2mysql(modelObject)
      List(modelObject).asJava
    } finally {
      if (cachedHere) df.unpersist()
    }
  }
}
1.2 模型评估
1.2.1 输入参数
{"labelColumn":""}
1.2.2 评估代码
import java.util
import com.cetc.common.conf.MachineLearnModel
import com.cetc.miner.compute.utils.{ModelUtils, Utils}
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.evaluation.{MulticlassClassificationEvaluator}
import org.apache.spark.sql.{DataFrame, SparkSession}
class DTCAssess {
  val logger: org.apache.log4j.Logger = org.apache.log4j.Logger.getLogger(classOf[DTCAssess])

  /**
   * Scores a stored decision-tree classification model on a test set and records its accuracy.
   *
   * Side effects: prints the prediction DataFrame via `show()`, and writes the accuracy through
   * `ModelUtils.updateModel2mysql` under `model.getName`.
   *
   * @param df           test data; converted by `Utils.trans2SupervisedLearning` using the
   *                     configured label column before scoring
   * @param model        descriptor of the model to load via `ModelUtils.loadModel`
   * @param id           not used by this method
   * @param name         not used by this method
   * @param conf         configuration parsed by `Utils.conf2Class`; only labelColumn is read
   * @param sparkSession not used by this method
   * @return a single-element Java list containing the accuracy
   */
  def execute(df: DataFrame, model: MachineLearnModel, id: String, name: String, conf: String, sparkSession: SparkSession): java.util.List[Double] = {
    val testCount = df.count()
    logger.info(s"测试集个数=========$testCount")

    val assessParams = Utils.conf2Class(conf)
    val label = assessParams.getLabelColumn
    val prepared = Utils.trans2SupervisedLearning(df, label)
    val treeModel = ModelUtils.loadModel[DecisionTreeClassificationModel](model)

    // Score the test set and dump the predictions for inspection.
    val predictions = treeModel.transform(prepared)
    predictions.show()

    // Accuracy = fraction of rows whose "prediction" equals the true label.
    val accuracy = new MulticlassClassificationEvaluator()
      .setLabelCol(label)
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
      .evaluate(predictions)
    logger.info(s"评估结果 正确率 accuracy==============$accuracy")

    ModelUtils.updateModel2mysql(model.getName, accuracy)

    val result = new util.ArrayList[Double](1)
    result.add(accuracy)
    result
  }
}