【spark】采用LogisticRegression(ML API篇)对MNIST的0-1数字进行识别

ROC曲线概念：

http://blog.csdn.net/abcjennifer/article/details/7359370

Recall-Precision（召回率-精确率）概念：

http://blog.csdn.net/pirage/article/details/9851339

下载MNIST数据集：

http://yann.lecun.com/exdb/mnist/

参考资料《Logistic Regression：从入门到精通》：

http://www.tianyancha.com/research/LR_intro.pdf

加载MNIST数据的类：

package com.bbw5.ml.spark.data

import java.io.File
import java.io.FileInputStream
import java.nio.ByteBuffer

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.Vector

/**
 * Loader for the MNIST handwritten-digit dataset stored in its original IDX binary format
 * (http://yann.lecun.com/exdb/mnist/).
 *
 * Each file is read from `dataDir` concatenated with one of the `*FileName` paths below.
 *
 * @param dataDir  directory that contains the unpacked MNIST files
 * @param numCount `loadData` keeps only images whose label is strictly less than this value;
 *                 the default of 2 keeps the digits 0 and 1
 */
class MNISTData(val dataDir: String, val numCount: Int = 2) {
  val trainLabeFileName = "/train-labels-idx1-ubyte/train-labels.idx1-ubyte"
  val trainImageFileName = "/train-images-idx3-ubyte/train-images.idx3-ubyte"
  val testLabeFileName = "/t10k-labels-idx1-ubyte/t10k-labels.idx1-ubyte"
  val testImageFileName = "/t10k-images-idx3-ubyte/t10k-images.idx3-ubyte"

  // IDX header layout; every field is a big-endian (MSB first) 32-bit integer:
  //   label file: magic 0x00000801 (2049), item count                -> 8-byte header
  //   image file: magic 0x00000803 (2051), item count, rows, columns -> 16-byte header
  private val LabelMagic = 0x00000801
  private val ImageMagic = 0x00000803
  private val LabelHeaderSize = 8
  private val ImageHeaderSize = 16

  /**
   * Runs `f` on `resource` and always closes the resource afterwards, even if `f` throws.
   */
  def using[A <: { def close(): Unit }, B](resource: A)(f: A => B): B =
    try {
      f(resource)
    } finally {
      resource.close()
    }

  /** Loads the raw bytes (header included) of the training label file. */
  def loadTrainLabelData(): Array[Byte] = {
    loadLabelData(trainLabeFileName)
  }

  /** Loads the raw bytes (header included) of the test label file. */
  def loadTestLabelData(): Array[Byte] = {
    loadLabelData(testLabeFileName)
  }

  /**
   * Loads an MNIST label file, prints its header and first three labels, and returns the raw
   * bytes (8-byte header followed by one byte per label).
   *
   * @param filename path of the file relative to `dataDir`
   * @throws IllegalArgumentException if the file is shorter than the header or its magic number is not 2049
   */
  def loadLabelData(filename: String): Array[Byte] = {
    val labelDS = readFile(filename)
    require(labelDS.length >= LabelHeaderSize, s"$filename is too short to be an IDX label file")

    val magicLabelNum = readInt(labelDS, 0)
    println(s"magicLabelNum=$magicLabelNum")
    require(magicLabelNum == LabelMagic,
      s"$filename: unexpected magic number $magicLabelNum, expected $LabelMagic")
    val numOfLabelItems = readInt(labelDS, 4)
    println(s"numOfLabelItems=$numOfLabelItems")
    // Print the first few labels as a sanity check.
    for ((e, index) <- labelDS.drop(LabelHeaderSize).take(3).zipWithIndex) {
      println(s"image$index is $e")
    }

    labelDS
  }

  /** Loads the raw bytes (header included) of the training image file. */
  def loadTrainImageData(): Array[Byte] = {
    loadImageData(trainImageFileName)
  }

  /** Loads the raw bytes (header included) of the test image file. */
  def loadTestImageData(): Array[Byte] = {
    loadImageData(testImageFileName)
  }

  /**
   * Loads an MNIST image file, prints its header, and returns the raw bytes (16-byte header
   * followed by the pixel bytes of every image, row by row).
   *
   * @param filename path of the file relative to `dataDir`
   * @throws IllegalArgumentException if the file is shorter than the header or its magic number is not 2051
   */
  def loadImageData(filename: String): Array[Byte] = {
    val trainingDS = readFile(filename)
    require(trainingDS.length >= ImageHeaderSize, s"$filename is too short to be an IDX image file")

    val magicNum = readInt(trainingDS, 0)
    println(s"magicNum=$magicNum")
    require(magicNum == ImageMagic, s"$filename: unexpected magic number $magicNum, expected $ImageMagic")
    val numOfItems = readInt(trainingDS, 4)
    println(s"numOfItems=$numOfItems")
    val numOfRows = readInt(trainingDS, 8)
    println(s"numOfRows=$numOfRows")
    val numOfCols = readInt(trainingDS, 12)
    println(s"numOfCols=$numOfCols")

    // Pixel bytes actually present vs. the number implied by the header (items * rows * cols).
    println(s"numOfItems=" + trainingDS.drop(ImageHeaderSize).length + "=" + (numOfItems * numOfRows * numOfCols))

    trainingDS
  }

  /** Loads the training split as (label, pixel-vector) pairs, keeping labels below `numCount`. */
  def loadTrainData(): Array[(Double, Vector)] = {
    loadData(loadTrainLabelData, loadTrainImageData)
  }

  /** Loads the test split as (label, pixel-vector) pairs, keeping labels below `numCount`. */
  def loadTestData(): Array[(Double, Vector)] = {
    loadData(loadTestLabelData, loadTestImageData)
  }

  /**
   * Pairs every label with its image and converts the pairs into (label, dense pixel vector).
   *
   * Prints how many images each label has, then keeps only images whose label is below
   * `numCount`. Pixel values are returned in the range 0.0..255.0.
   *
   * @param loadLabelFunc returns the raw bytes of a label file (see `loadLabelData`)
   * @param loadImageFunc returns the raw bytes of an image file (see `loadImageData`)
   * @throws IllegalArgumentException if the image file is truncated or the label and image counts differ
   */
  def loadData(loadLabelFunc: () => Array[Byte], loadImageFunc: () => Array[Byte]): Array[(Double, Vector)] = {
    val labels = loadLabelFunc().drop(LabelHeaderSize)

    val trainingDS = loadImageFunc()
    val numOfItems = readInt(trainingDS, 4)
    // Image size comes from the header (rows * columns) instead of a hard-coded 28 * 28.
    val imageSize = readInt(trainingDS, 8) * readInt(trainingDS, 12)
    require(trainingDS.length.toLong >= ImageHeaderSize + numOfItems.toLong * imageSize,
      s"image data is truncated: expected $numOfItems images of $imageSize bytes")
    val itemsArray = Array.tabulate(numOfItems) { i =>
      val start = ImageHeaderSize + i * imageSize
      trainingDS.slice(start, start + imageSize)
    }
    println("numOfImages=" + itemsArray.length)
    // zip would silently drop the surplus if the two files disagree; fail loudly instead.
    require(labels.length == itemsArray.length,
      s"label count ${labels.length} does not match image count ${itemsArray.length}")

    val data = labels.zip(itemsArray)
    // Print an overview of how many images each digit has.
    println("image digit count:")
    data.groupBy(_._1).mapValues(_.size).foreach(println)

    data.collect {
      case (label, pixels) if label < numCount =>
        // Pixels are unsigned bytes (0..255); Byte.toDouble alone would map 128..255 to negative values.
        (label.toDouble, Vectors.dense(pixels.map(p => (p & 0xFF).toDouble)))
    }
  }

  /** Reads the complete contents of `dataDir + filename` into memory. */
  private def readFile(filename: String): Array[Byte] = {
    val file = new File(dataDir + filename)
    val bytes = new Array[Byte](file.length.toInt)
    // A single InputStream.read call may return fewer bytes than requested; readFully keeps
    // reading until the buffer is full (or throws EOFException). The stream is closed by `using`.
    using(new java.io.DataInputStream(new FileInputStream(file))) { in =>
      in.readFully(bytes)
    }
    bytes
  }

  /** Decodes the big-endian (MSB first) 32-bit integer stored at `offset` in `bytes`. */
  private def readInt(bytes: Array[Byte], offset: Int): Int =
    ByteBuffer.wrap(bytes, offset, 4).getInt
}

/**
 * One-call entry points: each builds a fresh [[MNISTData]] loader and returns the
 * (label, pixel-vector) pairs of the requested split.
 */
object MNISTData {

  /** Training split, restricted to labels below `numCount` (default 2: digits 0 and 1). */
  def loadTrainData(dataDir: String, numCount: Int = 2): Array[(Double, Vector)] =
    loader(dataDir, numCount).loadTrainData()

  /** Test split, restricted to labels below `numCount` (default 2: digits 0 and 1). */
  def loadTestData(dataDir: String, numCount: Int = 2): Array[(Double, Vector)] =
    loader(dataDir, numCount).loadTestData()

  /** Creates a new loader per call, so no state is shared between calls. */
  private def loader(dataDir: String, numCount: Int): MNISTData =
    new MNISTData(dataDir = dataDir, numCount = numCount)
}
使用ML API进行分类：

package com.bbw5.ml.spark

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import com.bbw5.ml.spark.data.MNISTData
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.ParamGridBuilder
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.tuning.TrainValidationSplit
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary
import org.apache.spark.ml.classification.LogisticRegressionModel
import java.util.Date

/**
 * Uses LogisticRegression to classify the digits 0 and 1 of the MNIST handwritten-digit dataset,
 * choosing hyper-parameters with TrainValidationSplit.
 *
 * Usage: LogisticRegression4MNIST [dataDir] [modelDir]
 *
 * @author baibaiw5
 */
object LogisticRegression4MNIST {

  /** Default location of the unpacked MNIST files (the author's local path). */
  val DefaultDataDir = "I:/DM-dataset/MNIST/"

  /** Default directory under which the best model is saved (the author's local path). */
  val DefaultModelDir = "D:/Develop/Model"

  /**
   * Entry point.
   *
   * @param args optional: args(0) = MNIST data directory, args(1) = model output directory
   */
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("LogisticRegression4MNIST")
    val sc = new SparkContext(sparkConf)
    try {
      val sqlContext = new org.apache.spark.sql.SQLContext(sc)
      val dataDir = args.headOption.getOrElse(DefaultDataDir)
      val modelDir = args.drop(1).headOption.getOrElse(DefaultModelDir)
      tvSplit(sc, sqlContext, dataDir, modelDir)
    } finally {
      // Release the cluster resources even if training fails.
      sc.stop()
    }
  }

  /**
   * Saves `bestModel` under `modelDir` and prints its parameters and training summary
   * (loss per iteration, ROC and precision/recall curves, area under ROC).
   *
   * Every metric printed here comes from `bestModel.summary`, i.e. it is computed on the
   * training data, not on the test set.
   *
   * @param bestModel model to save and report on
   * @param modelDir  parent directory; the model is written to `modelDir/MNIST-LR-<currentTimeMillis>`
   */
  def printResult(bestModel: LogisticRegressionModel, modelDir: String = DefaultModelDir): Unit = {
    bestModel.save(modelDir + "/MNIST-LR-" + System.currentTimeMillis())

    println("bestModel.params:" + bestModel.extractParamMap)

    if (!bestModel.hasSummary) {
      println("bestModel has no training summary")
    } else {
      val trainingSummary = bestModel.summary
      // Objective (loss) value after each optimizer iteration.
      println("print lost in every step:")
      trainingSummary.objectiveHistory.foreach(loss => println(loss))

      trainingSummary match {
        case binarySummary: BinaryLogisticRegressionSummary =>
          // Receiver-operating characteristic as a DataFrame; show every row.
          println("print roc in every step:")
          val roc = binarySummary.roc
          roc.show(roc.count.toInt)

          println("print recall,precision every step:")
          val pr = binarySummary.pr
          pr.show(pr.count.toInt)
          // Training-set AUC (the original run printed 0.9409024010447935).
          println("areaUnderROC=" + binarySummary.areaUnderROC)
        case _ =>
          println("training summary is not binary; skipping ROC/PR output")
      }
    }
  }

  /**
   * Loads the MNIST 0/1 data, fits a LogisticRegression through TrainValidationSplit over a
   * parameter grid, prints predictions and test-set metrics, then reports on the best model.
   *
   * @param sc         active SparkContext
   * @param sqlContext SQLContext used to build the DataFrames
   * @param dataDir    directory with the unpacked MNIST files
   * @param modelDir   directory under which the best model is saved
   */
  def tvSplit(sc: SparkContext, sqlContext: SQLContext,
              dataDir: String = DefaultDataDir, modelDir: String = DefaultModelDir): Unit = {
    import sqlContext.implicits._
    val training = sc.parallelize(MNISTData.loadTrainData(dataDir), 4).toDF("label", "features").cache()
    training.describe("label").show
    val test = sc.parallelize(MNISTData.loadTestData(dataDir), 4).toDF("label", "features").cache()
    test.describe("label").show

    val lr = new LogisticRegression()

    // Candidates: 3 regParam values x both fitIntercept settings x 3 elasticNetParam values,
    // with maxIter fixed at 100.
    val paramGrid = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.0001, 0.01, 1.0))
      .addGrid(lr.fitIntercept)
      .addGrid(lr.maxIter, Array(100))
      .addGrid(lr.elasticNetParam, Array(0.1, 0.5, 1.0))
      .build()

    // 80% of the data will be used for training and the remaining 20% for validation.
    val trainValidationSplit = new TrainValidationSplit()
      .setEstimator(lr)
      .setEvaluator(new BinaryClassificationEvaluator)
      .setEstimatorParamMaps(paramGrid)
      .setTrainRatio(0.8)

    // Run train validation split, and choose the best set of parameters.
    val model = trainValidationSplit.fit(training)

    // Make predictions on test data with the best-performing parameter combination.
    val testDF = model.transform(test).cache()
    testDF.select("label", "prediction").show()
    testDF.groupBy("label", "prediction").count().show()

    // Held-out metrics; the summary printed by printResult only covers the training data.
    val accuracy = testDF.filter("label = prediction").count().toDouble / testDF.count()
    println("test accuracy=" + accuracy)
    println("test areaUnderROC=" + new BinaryClassificationEvaluator().evaluate(testDF))

    printResult(model.bestModel.asInstanceOf[LogisticRegressionModel], modelDir)
  }

}
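测试结果：下表是最优模型在测试集上的预测统计（混淆矩阵），准确率为 (1106+962)/2115 ≈ 97.8%。注意：程序打印的 areaUnderROC=0.9409024010447935 取自 bestModel.summary，是训练集上的 AUC，而不是测试集上的 AUC。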

        +-----+----------+-----+
        |label|prediction|count|
        +-----+----------+-----+
        |  1.0|       1.0| 1106|
        |  0.0|       0.0|  962|
        |  0.0|       1.0|   18|
        |  1.0|       0.0|   29|
        +-----+----------+-----+



  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值