A Simple Spark MLlib Logistic Regression Classifier Example

package com.huihex.sparkmllib

import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.{SparkConf, SparkContext}
/**
  * Created by wall-e on 2017/4/1.
  */
object Logistic_regression {
  /**
    * Logistic regression is a classic classification method in statistical
    * learning and belongs to the family of log-linear models. Its response
    * variable may be either binary or multiclass.
    * @param args command-line arguments (unused)
    */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("LogisticRegression").setMaster("local")
    val sc = new SparkContext(conf)
    // Load the data.
    // Each line has 5 comma-separated fields: the first 4 are the iris
    // features, the last is the species (class) name.
    // A LabeledPoint stores the label column together with the feature columns.
    val data = sc.textFile("data\\iris.txt")
    /*
    In supervised learning, LabeledPoint is commonly used to store a label and
    its features; the label must be a Double and the features a Vector.
    First map the species name to a numeric class: "Iris-setosa" -> 0,
    "Iris-versicolor" -> 1, and anything else -> 2. Then pack the 4 features
    into a dense Vector.
     */
    val parsedData = data.map { line =>
      val parts = line.split(',')
      LabeledPoint(
        if (parts(4) == "Iris-setosa") 0.0
        else if (parts(4) == "Iris-versicolor") 1.0
        else 2.0,
        Vectors.dense(parts(0).toDouble, parts(1).toDouble, parts(2).toDouble, parts(3).toDouble)
      )
    }
    // Print the parsed data.
    parsedData.foreach(println)

    /**
      * Split the data set: 60% for training and 40% for testing.
      */
    val splits = parsedData.randomSplit(Array(0.6,0.4),seed = 11L)
    val training = splits(0).cache()
    val test = splits(1)
    /**
      * Build the logistic regression model. Parameters are set with setter
      * methods; setNumClasses enables multinomial (multiclass) logistic
      * regression, here with 3 classes.
      */
    val model = new LogisticRegressionWithLBFGS().setNumClasses(3).run(training)
    /**
      * Next, call the model's predict method on the test data and collect
      * (prediction, label) pairs for evaluation with MulticlassMetrics.
      * The class is named LogisticRegressionWithLBFGS because it optimizes
      * with L-BFGS (Limited-memory BFGS). BFGS is a quasi-Newton method for
      * nonlinear optimization (here, maximizing the log-likelihood L(w));
      * it performs a rank-2 update of the Hessian approximation and is named
      * after its inventors Broyden, Fletcher, Goldfarb and Shanno.
      */
    val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
      val prediction = model.predict(features)
      (prediction, label)
    }
    /**
      * Each test row is destructured into its label and features; map applies
      * model.predict(features) to every row to obtain the prediction, and the
      * (prediction, label) pairs are collected in predictionAndLabels.
      * Print them to inspect the results:
      */
    predictionAndLabels.foreach(println)
    /**
      * Model evaluation: report the accuracy of the predictions.
      */
    val metrics = new MulticlassMetrics(predictionAndLabels)
    // Note: for multiclass data, metrics.precision is the overall accuracy;
    // in Spark 2.x it is deprecated in favor of metrics.accuracy.
    val precision = metrics.precision
    println("Precision = " + precision)
  }
}
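The comment in the code describes BFGS as a rank-2 update. For reference, the standard textbook form of the update (not specific to Spark's implementation) is:

```latex
B_{k+1} \;=\; B_k \;-\; \frac{B_k s_k s_k^{\top} B_k}{s_k^{\top} B_k s_k}
\;+\; \frac{y_k y_k^{\top}}{y_k^{\top} s_k},
\qquad s_k = w_{k+1} - w_k,\quad y_k = \nabla f(w_{k+1}) - \nabla f(w_k),
```

where B_k approximates the Hessian of the objective f (here one would minimize the negative log-likelihood -L(w)). The correction adds two rank-1 terms, hence "rank-2". L-BFGS avoids storing B_k explicitly and keeps only the last m pairs (s_k, y_k), which is what makes it practical for high-dimensional models.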

Iris dataset download link: https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
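Each line of iris.data looks like `5.1,3.5,1.4,0.2,Iris-setosa`. The parsing logic from the map above can be exercised without a SparkContext; this is a hypothetical standalone sketch (the object name IrisParse is invented here for illustration):

```scala
// Standalone version of the line-parsing logic used in parsedData above.
object IrisParse {
  // Returns (numeric label, feature values) for one CSV line of iris.data.
  def parse(line: String): (Double, Array[Double]) = {
    val parts = line.split(',')
    val label =
      if (parts(4) == "Iris-setosa") 0.0
      else if (parts(4) == "Iris-versicolor") 1.0
      else 2.0 // "Iris-virginica" (and anything else) maps to class 2
    (label, parts.take(4).map(_.toDouble))
  }
}
```

In the Spark program this pair would then be wrapped in a LabeledPoint with Vectors.dense.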

Sample run output (abridged):

......
17/04/06 22:16:29 INFO SparkContext: Created broadcast 1 from broadcast at DAGScheduler.scala:996
17/04/06 22:16:29 INFO DAGScheduler: Submitting 1 missing tasks from ResultStage 0 (MapPartitionsRDD[2] at map at Logistic_regression.scala:29)
17/04/06 22:16:29 INFO TaskSchedulerImpl: Adding task set 0.0 with 1 tasks
17/04/06 22:16:29 INFO TaskSetManager: Starting task 0.0 in stage 0.0 (TID 0, localhost, executor driver, partition 0, PROCESS_LOCAL, 5983 bytes)
17/04/06 22:16:29 INFO Executor: Running task 0.0 in stage 0.0 (TID 0)
17/04/06 22:16:29 INFO HadoopRDD: Input split: file:/D:/huihex-spark/data/iris.txt:0+4698
17/04/06 22:16:29 INFO deprecation: mapred.tip.id is deprecated. Instead, use mapreduce.task.id
17/04/06 22:16:29 INFO deprecation: mapred.task.id is deprecated. Instead, use mapreduce.task.attempt.id
17/04/06 22:16:29 INFO deprecation: mapred.task.is.map is deprecated. Instead, use mapreduce.task.ismap
17/04/06 22:16:29 INFO deprecation: mapred.task.partition is deprecated. Instead, use mapreduce.task.partition
17/04/06 22:16:29 INFO deprecation: mapred.job.id is deprecated. Instead, use mapreduce.job.id
(0.0,[5.1,3.5,1.4,0.2])
(0.0,[4.9,3.0,1.4,0.2])
(0.0,[4.7,3.2,1.3,0.2])
(0.0,[4.6,3.1,1.5,0.2])
......
(1.0,1.0)
(1.0,1.0)
(1.0,1.0)
(1.0,1.0)
17/04/06 22:16:33 INFO Executor: Finished task 0.0 in stage 72.0 (TID 72). 995 bytes result sent to driver
17/04/06 22:16:33 INFO TaskSetManager: Finished task 0.0 in stage 72.0 (TID 72) in 21 ms on localhost (executor driver) (1/1)
17/04/06 22:16:33 INFO TaskSchedulerImpl: Removed TaskSet 72.0, whose tasks have all completed, from pool 
17/04/06 22:16:33 INFO DAGScheduler: ResultStage 72 (foreach at Logistic_regression.scala:66) finished in 0.022 s
(2.0,2.0)
(2.0,2.0)
(2.0,2.0)
......
17/04/06 22:16:33 INFO Executor: Running task 0.0 in stage 76.0 (TID 76)
17/04/06 22:16:33 INFO ShuffleBlockFetcherIterator: Getting 1 non-empty blocks out of 1 blocks
17/04/06 22:16:33 INFO ShuffleBlockFetcherIterator: Started 0 remote fetches in 1 ms
Precision = 0.9615384615384616
17/04/06 22:16:33 INFO Executor: Finished task 0.0 in stage 76.0 (TID 76). 1807 bytes result sent to driver
17/04/06 22:16:33 INFO TaskSetManager: Finished task 0.0 in stage 76.0 (TID 76) in 10 ms on localhost (executor driver) (1/1)
17/04/06 22:16:33 INFO TaskSchedulerImpl: Removed TaskSet 76.0, whose tasks have all completed, from pool 
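For multiclass data, the precision printed above is the overall accuracy: the fraction of (prediction, label) pairs that agree. A minimal sketch of that computation (the helper object Accuracy and the sample pairs are hypothetical, not taken from the run above):

```scala
// Recomputes the overall accuracy from (prediction, label) pairs,
// mirroring what MulticlassMetrics.precision reports for multiclass data.
object Accuracy {
  def overall(pairs: Seq[(Double, Double)]): Double =
    pairs.count { case (p, l) => p == l }.toDouble / pairs.size
}
```

Applied to the real predictionAndLabels RDD one would instead use count on the RDD, but the arithmetic is the same.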

Reference: http://mocom.xmu.edu.cn/article/show/58578f482b2730e00d70f9fc/0/1
