package com.huihex.sparkmllib
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.{SparkConf, SparkContext}
/**
* Created by wall-e on 2017/4/1.
*/
object Logistic_regression {
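/**
* Logistic regression is a classic classification method in statistical learning and belongs to the family of log-linear models.
* The dependent variable can be binary or multiclass.
* @param args command-line arguments (unused)
*/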
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("Logistic Regression").setMaster("local")
val sc = new SparkContext(conf)
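//Read the data.
//Each line is split into 5 parts: the first 4 are the four features of an iris flower, and the last is its species.
//LabeledPoint is used here to store the label column and the feature columns.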
val data = sc.textFile("data\\iris.txt")
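/*
LabeledPoint is commonly used in supervised learning to store a label together with its features; the label must be a Double and the features a Vector.
Here the iris species is first mapped to a class index: "Iris-setosa" becomes 0, "Iris-versicolor" becomes 1, and anything else becomes 2;
then the four features are read and stored in a Vector.
*/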
val parsedData = data.map { line =>
val parts = line.split(',')
LabeledPoint(
if (parts(4) == "Iris-setosa") 0.0
else if (parts(4) == "Iris-versicolor") 1.0
else 2.0,
Vectors.dense(parts(0).toDouble, parts(1).toDouble, parts(2).toDouble, parts(3).toDouble)
)
}
//Print the parsed data
parsedData.foreach { x => println(x) }
/**
* First split the dataset: 60% for training and 40% for testing.
*/
val splits = parsedData.randomSplit(Array(0.6,0.4),seed = 11L)
val training = splits(0).cache()
val test = splits(1)
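/**
* Next, build the logistic regression model. Parameters are configured through setter methods, such as the number of classes; setting it to 3 gives a multiclass (multinomial) model.
*/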
val model = new LogisticRegressionWithLBFGS().setNumClasses(3).run(training)
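/**
* Then call the model's predict method on the test data; the results are passed to MulticlassMetrics further down.
* The model's full name is LogisticRegressionWithLBFGS; the LBFGS suffix stands for Limited-memory BFGS.
* BFGS is a method for solving nonlinear optimization problems (here, maximizing L(w)). It uses a rank-2 update
* and is named after the initials of its inventors: Broyden, Fletcher, Goldfarb, and Shanno.
*/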
val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
val prediction = model.predict(features)
(prediction, label)
}
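/**
* Above, each row of the test set is split into its label and its features;
* map then applies model.predict(features) to each row to obtain the prediction,
* and the prediction is paired with the true label in predictionAndLabels. The pairs can be printed for inspection:
*/
predictionAndLabels.foreach(x => println(x))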
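/**
* Model evaluation:
* overall accuracy of the predictions on the test set.
* Note: the no-argument metrics.precision used below reports this overall accuracy; it was deprecated in Spark 2.0 in favor of metrics.accuracy and removed in Spark 3.0.
*/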
val metrics = new MulticlassMetrics(predictionAndLabels)
val precision = metrics.precision
println("Precision = " + precision)
}
}
Iris dataset download link: https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
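The parsing step in the program indexes `parts(4)` directly, so a blank or malformed line in the input file would make the job fail with an exception. Below is a minimal sketch of a more defensive parser in plain Scala. The object `IrisParsing` and its methods `labelOf` and `parseLine` are hypothetical names introduced here for illustration; they are not part of the original program or of Spark.

```scala
import scala.util.Try

// Hypothetical helper, not part of the original program.
object IrisParsing {
  // Same mapping as the original code: setosa -> 0, versicolor -> 1, anything else -> 2.
  def labelOf(species: String): Double = species match {
    case "Iris-setosa"     => 0.0
    case "Iris-versicolor" => 1.0
    case _                 => 2.0
  }

  // Returns Some((label, features)) for a well-formed line with exactly 5 fields,
  // and None for blank lines or lines whose first 4 fields are not numeric.
  def parseLine(line: String): Option[(Double, Vector[Double])] = {
    val parts = line.trim.split(',')
    if (parts.length != 5) None
    else Try(parts.take(4).map(_.trim.toDouble).toVector)
      .toOption
      .map(features => (labelOf(parts(4).trim), features))
  }
}
```

In the Spark job, one way to use this would be to replace the `map` over `data` with a `flatMap` over `IrisParsing.parseLine`, and then build each `LabeledPoint` from the returned label and `Vectors.dense(features.toArray)`, so that unparsable lines are skipped instead of failing the task.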
......
17/04/06 22:16:29 INFO SparkContext: Created broadcast 1 from broadcast at DAGScheduler.scala:996
17/04/06 22:16:29 INFO DAGScheduler: Submitting 1 missing tasks from ResultStage 0 (MapPartitionsRDD[2] at map at Logistic_regression.scala:29)
17/04/06 22:16:29 INFO TaskSchedulerImpl: Adding task set 0.0 with 1 tasks
17/04/06 22:16:29 INFO TaskSetManager: Starting task 0.0 in stage 0.0 (TID 0, localhost, executor driver, partition 0, PROCESS_LOCAL, 5983 bytes)
17/04/06 22:16:29 INFO Executor: Running task 0.0 in stage 0.0 (TID 0)
17/04/06 22:16:29 INFO HadoopRDD: Input split: file:/D:/huihex-spark/data/iris.txt:0+4698
17/04/06 22:16:29 INFO deprecation: mapred.tip.id is deprecated. Instead, use mapreduce.task.id
17/04/06 22:16:29 INFO deprecation: mapred.task.id is deprecated. Instead, use mapreduce.task.attempt.id
17/04/06 22:16:29 INFO deprecation: mapred.task.is.map is deprecated. Instead, use mapreduce.task.ismap
17/04/06 22:16:29 INFO deprecation: mapred.task.partition is deprecated. Instead, use mapreduce.task.partition
17/04/06 22:16:29 INFO deprecation: mapred.job.id is deprecated. Instead, use mapreduce.job.id
(0.0,[5.1,3.5,1.4,0.2])
(0.0,[4.9,3.0,1.4,0.2])
(0.0,[4.7,3.2,1.3,0.2])
(0.0,[4.6,3.1,1.5,0.2])
......
(1.0,1.0)
(1.0,1.0)
(1.0,1.0)
(1.0,1.0)
17/04/06 22:16:33 INFO Executor: Finished task 0.0 in stage 72.0 (TID 72). 995 bytes result sent to driver
17/04/06 22:16:33 INFO TaskSetManager: Finished task 0.0 in stage 72.0 (TID 72) in 21 ms on localhost (executor driver) (1/1)
17/04/06 22:16:33 INFO TaskSchedulerImpl: Removed TaskSet 72.0, whose tasks have all completed, from pool
17/04/06 22:16:33 INFO DAGScheduler: ResultStage 72 (foreach at Logistic_regression.scala:66) finished in 0.022 s
(2.0,2.0)
(2.0,2.0)
(2.0,2.0)
......
17/04/06 22:16:33 INFO Executor: Running task 0.0 in stage 76.0 (TID 76)
17/04/06 22:16:33 INFO ShuffleBlockFetcherIterator: Getting 1 non-empty blocks out of 1 blocks
17/04/06 22:16:33 INFO ShuffleBlockFetcherIterator: Started 0 remote fetches in 1 ms
Precision = 0.9615384615384616
17/04/06 22:16:33 INFO Executor: Finished task 0.0 in stage 76.0 (TID 76). 1807 bytes result sent to driver
17/04/06 22:16:33 INFO TaskSetManager: Finished task 0.0 in stage 76.0 (TID 76) in 10 ms on localhost (executor driver) (1/1)
17/04/06 22:16:33 INFO TaskSchedulerImpl: Removed TaskSet 76.0, whose tasks have all completed, from pool
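The "Precision" value printed in the log is the overall fraction of test rows whose prediction equals the true label, that is, the accuracy. The sketch below shows that computation on an in-memory list of (prediction, label) pairs in plain Scala, without Spark. `AccuracySketch` and `accuracy` are hypothetical names used only for illustration, and the sample pairs are made up rather than taken from the run above.

```scala
// Hypothetical helper illustrating the metric; not the Spark implementation.
object AccuracySketch {
  // Fraction of pairs whose prediction equals the label.
  def accuracy(predictionAndLabels: Seq[(Double, Double)]): Double = {
    require(predictionAndLabels.nonEmpty, "need at least one (prediction, label) pair")
    val correct = predictionAndLabels.count { case (prediction, label) => prediction == label }
    correct.toDouble / predictionAndLabels.size
  }
}

object AccuracyDemo {
  def main(args: Array[String]): Unit = {
    // Made-up sample: three correct predictions and one wrong one.
    val pairs = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 2.0), (1.0, 2.0))
    println("Accuracy = " + AccuracySketch.accuracy(pairs)) // prints "Accuracy = 0.75"
  }
}
```

For per-class figures rather than a single overall number, `MulticlassMetrics` also offers `precision(label)` and `recall(label)`, which take the class label as a Double.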
Reference: http://mocom.xmu.edu.cn/article/show/58578f482b2730e00d70f9fc/0/1