spark运用逻辑回归算法操作Titanic数据集

/*

参考资料:
使用scala部署XGBoost算法:http://bailiwick.io/2017/08/21/using-xgboost-with-the-titanic-dataset-from-kaggle/
使用Java部署逻辑回归算法:https://blog.csdn.net/javafreely/article/details/81813492
使用scala操作iris数据集:http://dblab.xmu.edu.cn/blog/1510-2/
Titanic数据集下载地址:https://www.kaggle.com/c/titanic/data

*/


import org.apache.spark.ml.feature.{Imputer, StandardScaler}
import org.apache.spark.ml.feature.{StringIndexer, OneHotEncoderEstimator}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.classification.{LogisticRegressionModel,LogisticRegressionParams,LogisticRegressionSummary}

// Load the Kaggle Titanic training CSV: comma-separated, first row is a header, and
// column types are inferred by Spark (see the printSchema output below).
// The parentheses keep the multi-line chain as one expression when pasted into spark-shell.
val titanicDFCsv = (spark.read
  .format("csv")
  .option("header", "true")
  .option("inferSchema", "true")
  .option("sep", ",")
  .load("/titanic_data/train.csv"))


/*
scala> titanicDFCsv.printSchema
root
 |-- PassengerId: integer (nullable = true)
 |-- Survived: integer (nullable = true)
 |-- Pclass: integer (nullable = true)
 |-- Name: string (nullable = true)
 |-- Sex: string (nullable = true)
 |-- Age: double (nullable = true)
 |-- SibSp: integer (nullable = true)
 |-- Parch: integer (nullable = true)
 |-- Ticket: string (nullable = true)
 |-- Fare: double (nullable = true)
 |-- Cabin: string (nullable = true)
 |-- Embarked: string (nullable = true)
*/ 


// Replace Cabin with a 0/1 indicator: 1 when a cabin value is present, 0 when it is null.
val TrainingData = titanicDFCsv.withColumn("Cabin", when(col("Cabin").isNotNull, 1).otherwise(0))

/*
scala> TrainingData.show
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25|    0|       S|
|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|    1|       C|
|          3|       1|     3|Heikkinen, Miss. ...|female|26.0|    0|    0|STON/O2. 3101282|  7.925|    0|       S|
|          4|       1|     1|Futrelle, Mrs. Ja...|female|35.0|    1|    0|          113803|   53.1|    1|       S|
|          5|       0|     3|Allen, Mr. Willia...|  male|35.0|    0|    0|          373450|   8.05|    0|       S|
|          6|       0|     3|    Moran, Mr. James|  male|null|    0|    0|          330877| 8.4583|    0|       Q|
|          7|       0|     1|McCarthy, Mr. Tim...|  male|54.0|    0|    0|           17463|51.8625|    1|       S|
|          8|       0|     3|Palsson, Master. ...|  male| 2.0|    3|    1|          349909| 21.075|    0|       S|
|          9|       1|     3|Johnson, Mrs. Osc...|female|27.0|    0|    2|          347742|11.1333|    0|       S|
|         10|       1|     2|Nasser, Mrs. Nich...|female|14.0|    1|    0|          237736|30.0708|    0|       C|
|         11|       1|     3|Sandstrom, Miss. ...|female| 4.0|    1|    1|         PP 9549|   16.7|    1|       S|
|         12|       1|     1|Bonnell, Miss. El...|female|58.0|    0|    0|          113783|  26.55|    1|       S|
|         13|       0|     3|Saundercock, Mr. ...|  male|20.0|    0|    0|       A/5. 2151|   8.05|    0|       S|
|         14|       0|     3|Andersson, Mr. An...|  male|39.0|    1|    5|          347082| 31.275|    0|       S|
|         15|       0|     3|Vestrom, Miss. Hu...|female|14.0|    0|    0|          350406| 7.8542|    0|       S|
|         16|       1|     2|Hewlett, Mrs. (Ma...|female|55.0|    0|    0|          248706|   16.0|    0|       S|
|         17|       0|     3|Rice, Master. Eugene|  male| 2.0|    4|    1|          382652| 29.125|    0|       Q|
|         18|       1|     2|Williams, Mr. Cha...|  male|null|    0|    0|          244373|   13.0|    0|       S|
|         19|       0|     3|Vander Planke, Mr...|female|31.0|    1|    0|          345763|   18.0|    0|       S|
|         20|       1|     3|Masselmani, Mrs. ...|female|null|    0|    0|            2649|  7.225|    0|       C|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
only showing top 20 rows
*/

//统计各列字段缺失值个数
/*
参考资料:https://stackoverflow.com/questions/44413132/count-the-number-of-missing-values-in-a-dataframe-spark/44413456#44413456
*/

/*
scala> TrainingData.select(TrainingData.columns.map(c => sum(col(c).isNull.cast("int")).alias(c)): _*).show
+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+
|PassengerId|Survived|Pclass|Name|Sex|Age|SibSp|Parch|Ticket|Fare|Cabin|Embarked|
+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+
|          0|       0|     0|   0|  0|177|    0|    0|     0|   0|    0|       2|
+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+
*/

// Expose the data to Spark SQL, then look at the median Fare of first-class passengers
// per embarkation port (the null-Embarked group included) to pick a fill value for Embarked.
TrainingData.createOrReplaceTempView("trainFeatures")

val firstClassMedianFareSql = "SELECT Pclass,Embarked,percentile_approx(Fare, 0.5) AS Median_Fare FROM trainFeatures WHERE Fare IS NOT NULL AND Pclass = 1 GROUP BY Pclass,Embarked"
spark.sql(firstClassMedianFareSql).show()

/*
scala> spark.sql("SELECT Pclass,Embarked,percentile_approx(Fare, 0.5) AS Median_Fare FROM trainFeatures WHERE Fare IS NOT NULL AND Pclass = 1 GROUP BY Pclass,Embarked").show()
+------+--------+-----------+
|Pclass|Embarked|Median_Fare|
+------+--------+-----------+
|     1|    null|       80.0|
|     1|       Q|       90.0|
|     1|       C|    78.2667|
|     1|       S|       52.0|
+------+--------+-----------+
*/

// Fill the missing Embarked values with the constant port "C" (this is not a median fill).
// Rationale from the query above: first-class passengers with a null Embarked have a median
// fare of 80.0, which is closest to the first-class median fare for port C (78.2667).
val trainEmbarked = TrainingData.na.fill("C", Seq("Embarked"))
// Re-count nulls per column of the filled DataFrame to confirm Embarked has none left.
trainEmbarked.select(trainEmbarked.columns.map(c => sum(col(c).isNull.cast("int")).alias(c)): _*).show()
/*
scala> trainEmbarked.select(TrainingData.columns.map(c => sum(col(c).isNull.cast("int")).alias(c)): _*).show
+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+
|PassengerId|Survived|Pclass|Name|Sex|Age|SibSp|Parch|Ticket|Fare|Cabin|Embarked|
+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+
|          0|       0|     0|   0|  0|177|    0|    0|     0|   0|    0|       0|
+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+
*/


// Impute missing Age values into a new Age_imp column. No strategy is set, so Imputer's
// default (mean) is used; setStrategy("median") would switch to median imputation.
val imputer = new Imputer().setInputCols(Array("Age")).setOutputCols(Array("Age_imp"))


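// Next, one-hot encode the categorical variables. In Spark 2.3.2, combining StringIndexer with
// setHandleInvalid("keep") and OneHotEncoderEstimator avoids errors when the test set contains
// category values that do not appear in the training set.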

/*
参考资料:
http://spark.apache.org/docs/2.3.2/ml-features.html#onehotencoderestimator
https://issues.apache.org/jira/browse/SPARK-13030
https://www.cnblogs.com/realzjx/p/5854425.html

scikit-learn中OneHotEncoder官方文档:
https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OneHotEncoder.html
*/



// Convert the categorical (string) values into numeric values.
// Key setting: StringIndexer must use setHandleInvalid("keep").
// OneHotEncoderEstimator's dropLast parameter defaults to true and is not set explicitly below.
// true = the last category is dropped from the output vector; false = it is kept.
/*
The last category is not included by default (configurable via dropLast) because it makes the vector entries sum up to one, and hence linearly dependent.
*/

// Map the string categories Sex and Embarked to numeric indices. handleInvalid = "keep"
// keeps rows whose label was not seen during fitting instead of failing on them.
val genderIndexer = (new StringIndexer()
  .setInputCol("Sex")
  .setOutputCol("SexIndex")
  .setHandleInvalid("keep"))
val embarkIndexer = (new StringIndexer()
  .setInputCol("Embarked")
  .setOutputCol("EmbarkIndex")
  .setHandleInvalid("keep"))

// One-hot encode the index columns into binary {0,1} vectors. OneHotEncoderEstimator
// takes Array-typed input/output column lists; dropLast is left at its default (true).
val genderEncoder = (new OneHotEncoderEstimator()
  .setInputCols(Array("SexIndex"))
  .setOutputCols(Array("SexVec")))
val embarkEncoder = (new OneHotEncoderEstimator()
  .setInputCols(Array("EmbarkIndex"))
  .setOutputCols(Array("EmbarkVec")))



// Combine the numeric columns, the imputed Age and the one-hot vectors into one "features" vector.
val featureColumns = Array("Pclass", "SibSp", "Parch", "Fare", "Cabin", "Age_imp", "SexVec", "EmbarkVec")
val vectorAssembler = new VectorAssembler().setInputCols(featureColumns).setOutputCol("features")

// Standardize the assembled vector to unit standard deviation, without centering
// (withMean = false), writing the result to "scaledFeatures".
val scaler = (new StandardScaler()
  .setWithMean(false)
  .setWithStd(true)
  .setInputCol("features")
  .setOutputCol("scaledFeatures"))

// Feature-engineering pipeline; stages run in the order listed.
val trainingFeaturesPipeline = new Pipeline().setStages(
  Array(imputer, genderIndexer, embarkIndexer, genderEncoder, embarkEncoder, vectorAssembler, scaler))


// Split the raw rows BEFORE feature engineering, so the feature pipeline (Imputer mean,
// StringIndexer vocabularies, StandardScaler std) is learned from the training rows only.
// Fitting it on all rows first would leak statistics of the test rows into training.
val Array(trainRawDF, testRawDF) = trainEmbarked.randomSplit(Array(0.8, 0.2), seed = 12345)

val trainingFeaturesModel = trainingFeaturesPipeline.fit(trainRawDF)

// Features for every row, computed with the train-only fitted pipeline
// (kept under its original name for any downstream use).
val trainingFeaturesDF = trainingFeaturesModel.transform(trainEmbarked)

// Prepared training and test dataframes.
val trainDF = trainingFeaturesModel.transform(trainRawDF)
val testDF = trainingFeaturesModel.transform(testRawDF)



// Logistic regression on the standardized features. The regParam set here is only a
// starting value: the CrossValidator below searches the values in paramGrid.
val lr = (new LogisticRegression()
  .setMaxIter(100)
  .setRegParam(0.1)
  .setFeaturesCol("scaledFeatures")
  .setLabelCol("Survived")
  .setElasticNetParam(0)) // 0 = pure L2 penalty

// The model pipeline holds only the classifier; features were prepared beforehand.
val pipeline = (new Pipeline()
  .setStages(Array(lr)))

val paramGrid = (new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.05, 0.1))
  .build())

// AUC needs a continuous score. "rawPrediction" (the model's margin vector) yields a real
// ROC curve. The hard 0/1 "prediction" column collapses the ROC to a single threshold,
// misreporting AUC and skewing which regParam the cross-validation selects.
val evaluator = (new BinaryClassificationEvaluator()
  .setLabelCol("Survived")
  .setRawPredictionCol("rawPrediction")
  .setMetricName("areaUnderROC"))

// 3-fold cross-validation over paramGrid, scored with the evaluator above.
val cv = (new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3))

// Run cross-validation, and choose the best set of parameters.
val cvModel = cv.fit(trainDF)

// Score the held-out rows and display a sample of labels, probabilities and predictions.
// Without .show() the select only builds a DataFrame and prints nothing.
val test = cvModel.transform(testDF)
test.select("PassengerId", "Survived", "probability", "prediction").show()


/*
scala> test.select("PassengerId", "Survived", "probability", "prediction").show
+-----------+--------+--------------------+----------+
|PassengerId|Survived|         probability|prediction|
+-----------+--------+--------------------+----------+
|          5|       0|[0.88950692008834...|       0.0|
|          8|       0|[0.85683367108559...|       0.0|
|          9|       1|[0.41512197710691...|       1.0|
|         16|       1|[0.42466192593405...|       1.0|
|         17|       0|[0.81730567076689...|       0.0|
|         18|       1|[0.80460388469234...|       0.0|
|         36|       0|[0.76909426604402...|       0.0|
|         41|       0|[0.52095325993076...|       0.0|
|         43|       0|[0.81599634202170...|       0.0|
|         52|       0|[0.85728031095300...|       0.0|
|         57|       1|[0.26745049567398...|       1.0|
|         67|       1|[0.18197345040904...|       1.0|
|         73|       0|[0.75836226515332...|       0.0|
|         75|       1|[0.87558683140555...|       0.0|
|         77|       0|[0.87813924471160...|       0.0|
|         80|       1|[0.43291509090967...|       1.0|
|         81|       0|[0.85960968310027...|       0.0|
|         89|       1|[0.10470112282959...|       1.0|
|         94|       0|[0.88149513319149...|       0.0|
|        102|       0|[0.87813924471160...|       0.0|
+-----------+--------+--------------------+----------+
only showing top 20 rows
*/

// Area under the ROC curve of the cross-validated model on the held-out test rows.
val auc = evaluator.evaluate(test)
println("----AUC--------")
println(s"auc=$auc")


// Persist only the best pipeline found by cross-validation. write.overwrite() lets the
// script be re-run; a plain save() fails when the target path already exists.
val bestPipelineModel = cvModel.bestModel.asInstanceOf[PipelineModel]
bestPipelineModel.write.overwrite().save("/Titanic_best_model_20181227")

// Alias kept for downstream code; no need to re-cast cvModel.bestModel.
val bestModel = bestPipelineModel
// The classifier is the pipeline's last stage; .last avoids hard-coding its stage index.
val lrModel = bestModel.stages.last.asInstanceOf[LogisticRegressionModel]

// Print the fitted coefficients, intercept, number of classes and number of features.
println(
  s"Coefficients: ${lrModel.coefficientMatrix}\n" +
    s"Intercept: ${lrModel.interceptVector}\n" +
    s"numClasses: ${lrModel.numClasses}\n" +
    s"numFeatures: ${lrModel.numFeatures}")

// Regularization strength of the best model selected by cross-validation.
val bestRegParam = lrModel.getRegParam

// Binary-classification summary attached to the fitted model. It describes the data the
// model was trained on (trainDF), not the held-out testDF.
val summary = lrModel.binarySummary

// Training-set weighted precision, weighted recall and accuracy.
val precision = summary.weightedPrecision
val recall = summary.weightedRecall
val accuracy = summary.accuracy

// The same metrics on the held-out testDF, a fairer estimate of generalization.
val testSummary = lrModel.evaluate(testDF)
val testPrecision = testSummary.weightedPrecision
val testRecall = testSummary.weightedRecall
val testAccuracy = testSummary.accuracy


/*
scala> val precision = summary.weightedPrecision
precision: Double = 0.8051862498502815

scala> val recall = summary.weightedRecall
recall: Double = 0.8066378066378066

scala> val accuracy = summary.accuracy
accuracy: Double = 0.8066378066378066
*/

转载于:https://my.oschina.net/kyo4321/blog/2994570

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值