使用XGBoost结合OneHotEncoderEstimator操作Affairs数据集


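/*

此次操作的关键是:为StringIndexer设置参数setHandleInvalid("keep"),
同时为OneHotEncoderEstimator设置参数setDropLast(true)(该参数的默认值本身就是true)。
原因:训练集中未出现过的类别值在预测时会被StringIndexer映射到一个额外的索引(等于已知类别的个数),
而OneHotEncoderEstimator在dropLast=true时会丢弃向量的最后一位,因此这些未知值被编码为全零向量,不会报错。

*/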

import org.apache.spark.ml.feature.{StringIndexer, OneHotEncoderEstimator, VectorAssembler}
import org.apache.spark.ml.Pipeline
import ml.dmlc.xgboost4j.scala.spark.{XGBoostEstimator, XGBoostClassificationModel}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.ml.PipelineModel

// Read Affairs.csv (comma-separated, header row, column types inferred) and register it as the
// temporary view "res1" so the label can be derived with Spark SQL.
val data = (spark.read
  .options(Map("sep" -> ",", "inferSchema" -> "true", "header" -> "true"))
  .format("csv")
  .load("/Affairs.csv"))

data.createOrReplaceTempView("res1")

// Label expression: collapse the affairs count into a binary target (1 when > 0, else 0) while
// keeping the column name "affairs"; the listed feature columns are selected unchanged.
val affairs = "case when affairs>0 then 1 else 0 end as affairs,"
val df = {
  val passthrough = Seq("gender", "age", "yearsmarried", "children",
    "religiousness", "education", "occupation", "rating")
  spark.sql(s"select $affairs${passthrough.mkString(",")} from res1 ")
}


// Split the columns of `df` by Spark SQL type: string columns are categorical features that are
// indexed and one-hot encoded; all other columns are used as numeric features (minus the label).
val categoricals = df.dtypes.collect { case (name, "StringType") => name }

// handleInvalid = "keep": category values not seen while fitting are mapped to one extra index
// (the number of known labels) instead of failing at transform time.
val indexers = categoricals.map { c =>
  new StringIndexer()
    .setInputCol(c)
    .setOutputCol(s"${c}_idx")
    .setHandleInvalid("keep")
}

// One encoder per indexed column. dropLast = true (also Spark's default) drops the last vector
// slot. Combined with the indexers' "keep" mode, that slot is presumably the unseen-value bucket,
// so unseen values encode as an all-zero vector (as in the recorded (2,[],[]) outputs).
val encoders = categoricals.map { c =>
  new OneHotEncoderEstimator()
    .setInputCols(Array(s"${c}_idx"))
    .setOutputCols(Array(s"${c}_enc"))
    .setDropLast(true)
}

val colArray_enc = categoricals.map(c => s"${c}_enc")
val colArray_numeric = df.dtypes.collect { case (name, dtype) if dtype != "StringType" => name }

// Feature columns = numeric columns + encoded categoricals, minus the label "affairs".
// Compare by exact name so a feature whose name merely contains "affairs" is not dropped too.
val final_colArray = (colArray_numeric ++ colArray_enc).filterNot(_ == "affairs")
val vectorAssembler = new VectorAssembler().setInputCols(final_colArray).setOutputCol("features")


// XGBoost classifier that learns the binary "affairs" label from the assembled "features" vector.
// NOTE(review): "num_class" is documented by XGBoost for multi:* objectives; with binary:logistic
// it looks redundant. Confirm it can be dropped for the xgboost4j-spark version in use.
val xgb = (new XGBoostEstimator(Map(
    "num_class" -> 2,
    "num_rounds" -> 5,
    "objective" -> "binary:logistic",
    "booster" -> "gbtree"))
  .setFeaturesCol("features")
  .setLabelCol("affairs"))


// XGBoost parameter grid searched by the CrossValidator. Only maxDepth lists more than one value,
// so the grid expands to two candidate parameter maps (maxDepth = 10 and maxDepth = 20).
// Because every map sets `round`, each candidate trains with round = 10 regardless of the
// "num_rounds" entry passed to the estimator's constructor.
val xgbParamGrid = (new ParamGridBuilder()
   .addGrid(xgb.round, Array(10))
   .addGrid(xgb.maxDepth, Array(10,20))
   .addGrid(xgb.minChildWeight, Array(0.1))
   .addGrid(xgb.gamma, Array(0.1))
   .addGrid(xgb.subSample, Array(0.8))
   .addGrid(xgb.colSampleByTree, Array(0.90))
   .addGrid(xgb.alpha, Array(0.0))
   .addGrid(xgb.lambda, Array(0.6))
   // NOTE(review): a scale_pos_weight below 1 down-weights the positive class (affairs = 1).
   // XGBoost suggests roughly (#negatives / #positives) for imbalanced labels. Every prediction
   // in the recorded output is 0.0, so confirm 0.1 is intended.
   .addGrid(xgb.scalePosWeight, Array(0.1))
   .addGrid(xgb.eta, Array(0.4))
   // booster and objective repeat the values already given in the estimator's constructor map.
   .addGrid(xgb.boosterType, Array("gbtree"))
   .addGrid(xgb.objective, Array("binary:logistic")) 
   .build())

// Full pipeline. Stage order matters: each StringIndexer writes the *_idx column its encoder
// reads, and the assembler must build "features" before the XGBoost estimator consumes it.
val pipeline = {
  val stages = indexers ++ encoders ++ Array(vectorAssembler, xgb)
  new Pipeline().setStages(stages)
}


// Evaluator used by CrossValidator to rank candidate models by area under the ROC curve.
// AUC needs a continuous score. The hard 0/1 "prediction" column yields a single-threshold ROC
// that barely distinguishes candidates, so score with the model's "probabilities" column instead.
// For a vector column, BinaryClassificationEvaluator uses element 1, which here is the
// probability of affairs = 1.
val evaluator = (new BinaryClassificationEvaluator()
   .setLabelCol("affairs")
   .setRawPredictionCol("probabilities")
   .setMetricName("areaUnderROC"))


// Cross-validation over the whole pipeline: 3 folds, candidates taken from xgbParamGrid and
// ranked by `evaluator`; the fixed seed makes the fold assignment reproducible.
val cv = {
  val validator = new CrossValidator()
  validator.setEstimator(pipeline)
  validator.setEvaluator(evaluator)
  validator.setEstimatorParamMaps(xgbParamGrid)
  validator.setNumFolds(3)
  validator.setSeed(0)
  validator
}


// Hold out 20% of the rows for testing; seed 0 keeps the split reproducible across runs.
val Array(trainingData, testData) = df.randomSplit(weights = Array(0.8, 0.2), seed = 0)

// Cross-validate the pipeline over the parameter grid using only the training split.
val xgbModel = cv.fit(trainingData)

// Score the held-out rows with the selected model.
val results = xgbModel.transform(testData)

/*
scala> results.select("affairs", "probabilities", "prediction").show(false)
+-------+-----------------------------------------+----------+
|affairs|probabilities                            |prediction|
+-------+-----------------------------------------+----------+
|0      |[0.953326940536499,0.04667305573821068]  |0.0       |
|0      |[0.9757875800132751,0.02421243116259575] |0.0       |
|0      |[0.9757875800132751,0.02421243116259575] |0.0       |
|0      |[0.971818208694458,0.028181808069348335] |0.0       |
|0      |[0.9677507877349854,0.03224920853972435] |0.0       |
|0      |[0.9413658380508423,0.05863417685031891] |0.0       |
|0      |[0.9653270244598389,0.03467295691370964] |0.0       |
|0      |[0.9731520414352417,0.02684793807566166] |0.0       |
|0      |[0.9788634777069092,0.02113652601838112] |0.0       |
|0      |[0.9765547513961792,0.023445265367627144]|0.0       |
|0      |[0.9788634777069092,0.02113652601838112] |0.0       |
|0      |[0.9767565131187439,0.023243505507707596]|0.0       |
|0      |[0.9582042694091797,0.04179573059082031] |0.0       |
|0      |[0.9732727408409119,0.026727236807346344]|0.0       |
|0      |[0.9719892740249634,0.02801070734858513] |0.0       |
|0      |[0.9580785632133484,0.04192144423723221] |0.0       |
|0      |[0.948063850402832,0.05193614959716797]  |0.0       |
|0      |[0.9788634777069092,0.02113652601838112] |0.0       |
|0      |[0.9681408405303955,0.03185916319489479] |0.0       |
|0      |[0.9661509394645691,0.03384905681014061] |0.0       |
+-------+-----------------------------------------+----------+
only showing top 20 rows
*/

/*
scala> results.select("gender", "children", "gender_idx","children_idx","gender_enc","children_enc").show()
+------+--------+----------+------------+-------------+-------------+
|gender|children|gender_idx|children_idx|   gender_enc| children_enc|
+------+--------+----------+------------+-------------+-------------+
|female|      no|       0.0|         1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female|      no|       0.0|         1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female|      no|       0.0|         1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female|      no|       0.0|         1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female|      no|       0.0|         1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female|      no|       0.0|         1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female|      no|       0.0|         1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female|      no|       0.0|         1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female|      no|       0.0|         1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female|      no|       0.0|         1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female|      no|       0.0|         1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female|      no|       0.0|         1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female|     yes|       0.0|         0.0|(3,[0],[1.0])|(3,[0],[1.0])|
|female|     yes|       0.0|         0.0|(3,[0],[1.0])|(3,[0],[1.0])|
|female|      no|       0.0|         1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female|     yes|       0.0|         0.0|(3,[0],[1.0])|(3,[0],[1.0])|
|female|     yes|       0.0|         0.0|(3,[0],[1.0])|(3,[0],[1.0])|
|female|      no|       0.0|         1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female|      no|       0.0|         1.0|(3,[0],[1.0])|(3,[1],[1.0])|
|female|     yes|       0.0|         0.0|(3,[0],[1.0])|(3,[0],[1.0])|
+------+--------+----------+------------+-------------+-------------+
only showing top 20 rows
*/


// Two hand-made rows whose gender/children values ("hello"/"world", "foo"/"bar") never occur in
// the training data, used to check that the fitted pipeline still scores them.
val df1 = {
  val rows = Seq(
    (0.0, "hello", 57.0, 15.0, "foo", 3.0, 16.0, 6.0, 1.0),
    (0.0, "world", 27.0, 10.0, "bar", 5.0, 14.0, 1.0, 5.0))
  spark.createDataFrame(rows).toDF("affairs", "gender", "age", "yearsmarried", "children",
    "religiousness", "education", "occupation", "rating")
}

val pred_results = xgbModel.transform(df1)

pred_results.select("affairs", "probabilities", "prediction").show(false)
/*

scala> pred_results.select("affairs", "probabilities", "prediction").show(false)
+-------+----------------------------------------+----------+
|affairs|probabilities                           |prediction|
+-------+----------------------------------------+----------+
|0.0    |[0.9058669209480286,0.09413306415081024]|0.0       |
|0.0    |[0.9630419611930847,0.03695804998278618]|0.0       |
+-------+----------------------------------------+----------+

*/

pred_results.select("gender", Seq("children", "gender_idx", "children_idx", "gender_enc", "children_enc"): _*).show()
/*
scala> pred_results.select("gender", "children", "gender_idx","children_idx","gender_enc","children_enc").show()
+------+--------+----------+------------+----------+------------+
|gender|children|gender_idx|children_idx|gender_enc|children_enc|
+------+--------+----------+------------+----------+------------+
| hello|     foo|       2.0|         2.0| (2,[],[])|   (2,[],[])|
| world|     bar|       2.0|         2.0| (2,[],[])|   (2,[],[])|
+------+--------+----------+------------+----------+------------+
*/


// Deliberately use gender/children values that never appear in the training data. transform does
// not throw: the indexers map the unseen values to the extra "keep" index, which encodes to an
// all-zero vector. Rows that differ only in unseen category values therefore get the same
// prediction. For example, the "boy"/"good" row below and the "hello"/"foo" row of df1 share
// numeric values and received identical probabilities in the recorded output.
val df3 = {
  val rows = Seq(
    (0.0, "boy", 57.0, 15.0, "good", 3.0, 16.0, 6.0, 1.0),
    (0.0, "man", 37.0, 15.0, "great", 3.0, 17.0, 5.0, 5.0),
    (0.0, "girl", 27.0, 10.0, "bad", 5.0, 14.0, 1.0, 5.0))
  spark.createDataFrame(rows).toDF("affairs", "gender", "age", "yearsmarried", "children",
    "religiousness", "education", "occupation", "rating")
}

val pred_results_3 = xgbModel.transform(df3)

/*
scala> val pred_results_3 = xgbModel.transform(df3)
pred_results_3: org.apache.spark.sql.DataFrame = [affairs: double, gender: string ... 14 more fields]

scala> pred_results_3.select("affairs", "probabilities", "prediction").show(false)
+-------+----------------------------------------+----------+
|affairs|probabilities                           |prediction|
+-------+----------------------------------------+----------+
|0.0    |[0.9058669209480286,0.09413306415081024]|0.0       |
|0.0    |[0.968347430229187,0.03165258467197418] |0.0       |
|0.0    |[0.9630419611930847,0.03695804998278618]|0.0       |
+-------+----------------------------------------+----------+


scala> pred_results_3.select("gender", "children", "gender_idx","children_idx","gender_enc","children_enc").show()
+------+--------+----------+------------+----------+------------+
|gender|children|gender_idx|children_idx|gender_enc|children_enc|
+------+--------+----------+------------+----------+------------+
|   boy|    good|       2.0|         2.0| (2,[],[])|   (2,[],[])|
|   man|   great|       2.0|         2.0| (2,[],[])|   (2,[],[])|
|  girl|     bad|       2.0|         2.0| (2,[],[])|   (2,[],[])|
+------+--------+----------+------------+----------+------------+
*/

转载于:https://my.oschina.net/kyo4321/blog/2994576

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值