import org.apache.spark.SparkContext
import org.apache.spark.SparkConf
import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, LogisticRegressionWithSGD}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.optimization._
/**
 * Binary gender classifier: trains logistic regression (mini-batch SGD) on a
 * pipe-delimited local file and reports area under the ROC curve on a 30%
 * hold-out split.
 *
 * Created by simon on 2017/5/8.
 */
object genderClassificationWithLogisticRegression {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("genderClassification")
      .setMaster("local[2]")
    val sc = new SparkContext(conf)
    try {
      // NOTE(review): "file:\E:\test.csv" is not a standard file URI; the
      // canonical form is "file:///E:/test.csv" — confirm this path actually
      // resolves on the target (Windows) machine.
      val trainData = sc.textFile("file:\\E:\\test.csv")

      // Records are pipe-delimited despite the .csv extension.
      val parsedTrainData = trainData.map { line =>
        val parts = line.split("\\|")
        // Column 1 carries the gender flag ("m" => 1, anything else => 0).
        val label = toInt(parts(1))
        // Features start at column 6. slice's upper bound is exclusive, so the
        // final column is dropped — presumably a non-feature trailing field.
        // TODO(review): confirm the last column is meant to be excluded.
        val features = Vectors.dense(parts.slice(6, parts.length - 1).map(_.toDouble))
        LabeledPoint(label, features)
      }.cache()

      // 70/30 train/test split; fixed seed keeps runs reproducible.
      val splits = parsedTrainData.randomSplit(Array(0.7, 0.3), seed = 11L)
      val training = splits(0)
      val testing = splits(1)

      val model = new LogisticRegressionWithSGD()
      model.optimizer
        .setNumIterations(500)
        .setUpdater(new SimpleUpdater())
        .setStepSize(0.001)
        .setMiniBatchFraction(0.02)
      val trained = model.run(training)

      // Clear the default 0.5 decision threshold so predict() emits raw
      // scores. BinaryClassificationMetrics needs continuous scores — not
      // hard 0/1 labels — to trace a meaningful ROC curve; with thresholded
      // predictions the reported AUC is just a single operating point.
      trained.clearThreshold()
      val prediction = trained.predict(testing.map(_.features))
      val predictionAndLabels = prediction.zip(testing.map(_.label))

      val metrics = new BinaryClassificationMetrics(predictionAndLabels)
      val auROC = metrics.areaUnderROC
      println("Area under ROC = " + auROC)
    } finally {
      // Always release the local Spark context, even if the job fails.
      sc.stop()
    }
  }

  /** Maps the raw gender field to a numeric label: "m" => 1, otherwise 0. */
  def toInt(s: String): Int = {
    if (s == "m") 1 else 0
  }
}