/*
The key points of this run: add setHandleInvalid("keep") to the StringIndexer,
and set setDropLast(true) on the OneHotEncoderEstimator (true is the default).
*/
import org.apache.spark.ml.feature.{StringIndexer, OneHotEncoderEstimator, VectorAssembler}
import org.apache.spark.ml.Pipeline
import ml.dmlc.xgboost4j.scala.spark.{XGBoostEstimator, XGBoostClassificationModel}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.ml.PipelineModel
// Load the Affairs CSV, reading the header row and inferring column types.
val data = spark.read
  .format("csv")
  .option("sep", ",")
  .option("header", "true")
  .option("inferSchema", "true")
  .load("/Affairs.csv")

// Register the raw data as a temp view, then binarize the label:
// any affairs count > 0 becomes 1, otherwise 0.
data.createOrReplaceTempView("res1")
val affairs = "case when affairs>0 then 1 else 0 end as affairs,"
val df = spark.sql(
  "select " + affairs +
    "gender,age,yearsmarried,children,religiousness,education,occupation,rating" +
    " from res1 ")
// String-typed columns need index + one-hot encoding before assembly.
val categoricals = df.dtypes.collect { case (col, "StringType") => col }

// One StringIndexer per categorical column. handleInvalid("keep") routes
// labels unseen at fit time to an extra index instead of failing.
val indexers = categoricals.map { col =>
  new StringIndexer()
    .setInputCol(col)
    .setOutputCol(s"${col}_idx")
    .setHandleInvalid("keep")
}

// One-hot encode each indexed column. With dropLast(true) the final
// category slot is dropped, so the extra "keep" index comes out as an
// all-zero vector rather than a dedicated dimension.
val encoders = categoricals.map { col =>
  new OneHotEncoderEstimator()
    .setInputCols(Array(s"${col}_idx"))
    .setOutputCols(Array(s"${col}_enc"))
    .setDropLast(true)
}

val colArray_enc = categoricals.map(_ + "_enc")
// Numeric (non-string) columns feed into the feature vector directly.
val colArray_numeric = df.dtypes.collect { case (col, tpe) if tpe != "StringType" => col }
// All candidate features, minus the label column.
val final_colArray = (colArray_numeric ++ colArray_enc).filterNot(_.contains("affairs"))
val vectorAssembler = new VectorAssembler()
  .setInputCols(final_colArray)
  .setOutputCol("features")
// Create an XGBoost classifier (deprecated xgboost4j-spark 0.7x API; newer
// releases replace XGBoostEstimator with XGBoostClassifier).
// NOTE(review): "num_class" is normally only meaningful for multiclass
// objectives; paired with "binary:logistic" it is likely ignored here, and
// newer XGBoost versions reject the combination — confirm against the
// xgboost4j version in use. "num_rounds" spelling should also be checked.
val xgb = new XGBoostEstimator(Map("num_class" -> 2, "num_rounds" -> 5, "objective" -> "binary:logistic", "booster" -> "gbtree")).setLabelCol("affairs").setFeaturesCol("features")
// XGBoost parameter grid. Every entry below overrides the corresponding
// constructor setting; only maxDepth has two values, so cross-validation
// compares exactly two candidate models (maxDepth 10 vs 20).
val xgbParamGrid = (new ParamGridBuilder()
.addGrid(xgb.round, Array(10))
.addGrid(xgb.maxDepth, Array(10,20))
.addGrid(xgb.minChildWeight, Array(0.1))
.addGrid(xgb.gamma, Array(0.1))
.addGrid(xgb.subSample, Array(0.8))
.addGrid(xgb.colSampleByTree, Array(0.90))
.addGrid(xgb.alpha, Array(0.0))
.addGrid(xgb.lambda, Array(0.6))
.addGrid(xgb.scalePosWeight, Array(0.1))
.addGrid(xgb.eta, Array(0.4))
.addGrid(xgb.boosterType, Array("gbtree"))
.addGrid(xgb.objective, Array("binary:logistic"))
.build())
// Full pipeline: index + encode the categoricals, assemble the feature
// vector, then train XGBoost.
val pipeline = new Pipeline().setStages(indexers ++ encoders ++ Array(vectorAssembler, xgb))
// Binary classification evaluator for model selection.
// FIX: score AUC on the "probabilities" vector (emitted by the XGBoost
// model, see the sample output below) instead of the hard 0/1 "prediction"
// column — ranking by a binary label collapses the ROC curve to a single
// operating point and makes areaUnderROC meaningless.
val evaluator = (new BinaryClassificationEvaluator()
.setLabelCol("affairs")
.setRawPredictionCol("probabilities")
.setMetricName("areaUnderROC"))
// 3-fold cross-validation over the whole pipeline, searching xgbParamGrid
// and selecting the model with the best AUC.
val cv = (new CrossValidator()
.setEstimator(pipeline)
.setEvaluator(evaluator)
.setEstimatorParamMaps(xgbParamGrid)
.setNumFolds(3)
.setSeed(0))
// 80/20 train/test split; fixed seed keeps the split reproducible.
val Array(trainingData, testData) = df.randomSplit(Array(0.8, 0.2), seed=0)
// Fit on the training data — this runs the full cross-validation search
// and returns the best model found.
val xgbModel = cv.fit(trainingData)
// Score the held-out test data with the best model.
val results = xgbModel.transform(testData)
/*
scala> results.select("affairs", "probabilities", "prediction").show(false)
+-------+-----------------------------------------+----------+
|affairs|probabilities |prediction|
+-------+-----------------------------------------+----------+
|0 |[0.953326940536499,0.04667305573821068] |0.0 |
|0 |[0.9757875800132751,0.02421243116259575] |0.0 |
|0 |[0.9757875800132751,0.02421243116259575] |0.0 |
|0 |[0.971818208694458,0.028181808069348335] |0.0 |
|0 |[0.9677507877349854,0.03224920853972435] |0.0 |
|0 |[0.9413658380508423,0.05863417685031891] |0.0 |
|0 |[0.9653270244598389,0.03467295691370964] |0.0 |
|0 |[0.9731520414352417,0.02684793807566166] |0.0 |
|0 |[0.9788634777069092,0.02113652601838112] |0.0 |
|0 |[0.9765547513961792,0.023445265367627144]|0.0 |
|0 |[0.9788634777069092,0.02113652601838112] |0.0 |
|0 |[0.9767565131187439,0.023243505507707596]|0.0 |
|0 |[0.9582042694091797,0.04179573059082031] |0.0 |
|0 |[0.9732727408409119,0.026727236807346344]|0.0 |
|0 |[0.9719892740249634,0.02801070734858513] |0.0 |
|0 |[0.9580785632133484,0.04192144423723221] |0.0 |
|0 |[0.948063850402832,0.05193614959716797] |0.0 |
|0 |[0.9788634777069092,0.02113652601838112] |0.0 |
|0 |[0.9681408405303955,0.03185916319489479] |0.0 |
|0 |[0.9661509394645691,0.03384905681014061] |0.0 |
+-------+-----------------------------------------+----------+
only showing top 20 rows
*/
/*
scala> results.select("gender", "children", "gender_idx","children_idx","gender_enc","children_enc").show()
+------+--------+----------+------------+-------------+-------------+
|gender|children|gender_idx|children_idx| gender_enc| children_enc|
+------+--------+----------+------------+-------------+-------------+
|female| no| 0.0| 1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female| no| 0.0| 1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female| no| 0.0| 1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female| no| 0.0| 1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female| no| 0.0| 1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female| no| 0.0| 1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female| no| 0.0| 1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female| no| 0.0| 1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female| no| 0.0| 1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female| no| 0.0| 1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female| no| 0.0| 1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female| no| 0.0| 1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female| yes| 0.0| 0.0|(3,[0],[1.0])|(3,[0],[1.0])|
|female| yes| 0.0| 0.0|(3,[0],[1.0])|(3,[0],[1.0])|
|female| no| 0.0| 1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female| yes| 0.0| 0.0|(3,[0],[1.0])|(3,[0],[1.0])|
|female| yes| 0.0| 0.0|(3,[0],[1.0])|(3,[0],[1.0])|
|female| no| 0.0| 1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female| no| 0.0| 1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female| yes| 0.0| 0.0|(3,[0],[1.0])|(3,[0],[1.0])|
+------+--------+----------+------------+-------------+-------------+
only showing top 20 rows
*/
// Score two hand-built rows whose gender/children values ("hello"/"foo",
// "world"/"bar") never appeared in the training data.
val df1 = spark
  .createDataFrame(
    Seq(
      (0.0, "hello", 57.0, 15.0, "foo", 3.0, 16.0, 6.0, 1.0),
      (0.0, "world", 27.0, 10.0, "bar", 5.0, 14.0, 1.0, 5.0)))
  .toDF("affairs", "gender", "age", "yearsmarried", "children",
    "religiousness", "education", "occupation", "rating")
val pred_results = xgbModel.transform(df1)
pred_results.select("affairs", "probabilities", "prediction").show(false)
/*
scala> pred_results.select("affairs", "probabilities", "prediction").show(false)
+-------+----------------------------------------+----------+
|affairs|probabilities |prediction|
+-------+----------------------------------------+----------+
|0.0 |[0.9058669209480286,0.09413306415081024]|0.0 |
|0.0 |[0.9630419611930847,0.03695804998278618]|0.0 |
+-------+----------------------------------------+----------+
*/
// Inspect how the unseen categorical values were handled: per the output
// below, both map to the extra "keep" index (2.0) and encode as an
// all-zero one-hot vector.
pred_results.select("gender", "children", "gender_idx","children_idx","gender_enc","children_enc").show()
/*
scala> pred_results.select("gender", "children", "gender_idx","children_idx","gender_enc","children_enc").show()
+------+--------+----------+------------+----------+------------+
|gender|children|gender_idx|children_idx|gender_enc|children_enc|
+------+--------+----------+------------+----------+------------+
| hello| foo| 2.0| 2.0| (2,[],[])| (2,[],[])|
| world| bar| 2.0| 2.0| (2,[],[])| (2,[],[])|
+------+--------+----------+------------+----------+------------+
*/
// The gender and children columns were deliberately given values never seen
// in the training set; no error is raised and predictions are still produced.
// A second batch of rows with unseen gender/children values, to confirm
// the trained model scores them without error.
val df3 = spark
  .createDataFrame(
    Seq(
      (0.0, "boy", 57.0, 15.0, "good", 3.0, 16.0, 6.0, 1.0),
      (0.0, "man", 37.0, 15.0, "great", 3.0, 17.0, 5.0, 5.0),
      (0.0, "girl", 27.0, 10.0, "bad", 5.0, 14.0, 1.0, 5.0)))
  .toDF("affairs", "gender", "age", "yearsmarried", "children",
    "religiousness", "education", "occupation", "rating")
val pred_results_3 = xgbModel.transform(df3)
/*
scala> val pred_results_3 = xgbModel.transform(df3)
pred_results_3: org.apache.spark.sql.DataFrame = [affairs: double, gender: string ... 14 more fields]
scala> pred_results_3.select("affairs", "probabilities", "prediction").show(false)
+-------+----------------------------------------+----------+
|affairs|probabilities |prediction|
+-------+----------------------------------------+----------+
|0.0 |[0.9058669209480286,0.09413306415081024]|0.0 |
|0.0 |[0.968347430229187,0.03165258467197418] |0.0 |
|0.0 |[0.9630419611930847,0.03695804998278618]|0.0 |
+-------+----------------------------------------+----------+
scala> pred_results_3.select("gender", "children", "gender_idx","children_idx","gender_enc","children_enc").show()
+------+--------+----------+------------+----------+------------+
|gender|children|gender_idx|children_idx|gender_enc|children_enc|
+------+--------+----------+------------+----------+------------+
| boy| good| 2.0| 2.0| (2,[],[])| (2,[],[])|
| man| great| 2.0| 2.0| (2,[],[])| (2,[],[])|
| girl| bad| 2.0| 2.0| (2,[],[])| (2,[],[])|
+------+--------+----------+------------+----------+------------+
*/
// Reposted from: https://my.oschina.net/kyo4321/blog/2994576