【Spark+Python】采用LogisticRegression(MLlib)对MNIST的0/1数字进行识别

1. 下载数据集

http://yann.lecun.com/exdb/mnist/

2. 使用Python查看原始数据集的图片

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import struct

# Display the first image stored in an MNIST IDX3 image file.
# IDX3 layout: four big-endian unsigned 32-bit ints (magic = 2051, image count,
# rows, cols) followed by one unsigned byte per pixel, row-major.
IMAGE_PATH = 'D:/DM-datasets/MNIST/t10k-images-idx3-ubyte/t10k-images.idx3-ubyte'

with open(IMAGE_PATH, 'rb') as mnistData:  # closed automatically, even on error
    magic, numImages, numRows, numCols = struct.unpack('>IIII', mnistData.read(16))
    if magic != 2051:
        raise ValueError('not an IDX3 image file (magic=%d)' % magic)
    pixelCount = numRows * numCols
    pixels = mnistData.read(pixelCount)
    if len(pixels) != pixelCount:
        raise ValueError('file truncated: expected %d pixel bytes, got %d'
                         % (pixelCount, len(pixels)))
    # bytearray yields unsigned ints 0..255 on both Python 2 and Python 3
    image1 = list(bytearray(pixels))

print(image1)
arr = np.array(image1)
plt.imshow(arr.reshape((numRows, numCols)), cmap=cm.Greys_r)
plt.axis('off')
plt.show()

3. 使用Spark MLlib做Logistic Regression

package com.bbw5.ml.spark

import java.io.ByteArrayInputStream
import java.io.DataInputStream
import java.io.File
import java.io.FileInputStream
import java.nio.ByteBuffer
import javax.imageio.ImageIO

import scala.collection.mutable.ArrayBuffer
import scala.language.reflectiveCalls

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.{ LogisticRegressionWithLBFGS, LogisticRegressionModel }
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
/**
 * Classifies the digits 0 and 1 of the MNIST handwritten-digit data set with
 * MLlib's `LogisticRegressionWithLBFGS`.
 *
 * Optional program arguments: `args(0)` = IDX1 label file, `args(1)` = IDX3 image file.
 * When omitted, the t10k files under `D:/DM-datasets/MNIST` are used.
 *
 * @author baibaiw5
 */
object LogisticRegression4MNIST {

  /** IDX1 (label file) magic number, 0x00000801. */
  private val LabelMagic = 2049
  /** IDX3 (image file) magic number, 0x00000803. */
  private val ImageMagic = 2051
  /** IDX1 header: magic number + item count, two big-endian 32-bit ints. */
  private val LabelHeaderSize = 8
  /** IDX3 header: magic number + item count + rows + cols, four big-endian 32-bit ints. */
  private val ImageHeaderSize = 16

  // NOTE: despite the "Train" in the loader names, these defaults are the 10k *test* files.
  private val DefaultLabelPath = "D:/DM-datasets/MNIST/t10k-labels-idx1-ubyte/t10k-labels.idx1-ubyte"
  private val DefaultImagePath = "D:/DM-datasets/MNIST/t10k-images-idx3-ubyte/t10k-images.idx3-ubyte"

  def main(args: Array[String]): Unit = {
    val labelPath = args.headOption.getOrElse(DefaultLabelPath)
    val imagePath = args.drop(1).headOption.getOrElse(DefaultImagePath)

    val sparkConf = new SparkConf().setAppName("LogisticRegression4MNIST")
    val sc = new SparkContext(sparkConf)
    try {
      val labelDS: Array[Byte] = loadMNISTTrainLabelDataSet(labelPath)
      val labels = labelDS.drop(LabelHeaderSize)

      val trainingDS: Array[Byte] = loadMNISTTrainDataSet(imagePath)
      val numOfItems = readInt(trainingDS, 4)
      // Image geometry comes from the header instead of a hard-coded 28 x 28.
      val pixelsPerImage = readInt(trainingDS, 8) * readInt(trainingDS, 12)
      // zip would silently truncate on a mismatched label/image pair of files.
      require(labels.length == numOfItems,
        s"label count ${labels.length} does not match image count $numOfItems")

      val itemsArray = Array.tabulate(numOfItems) { i =>
        val start = ImageHeaderSize + i * pixelsPerImage
        trainingDS.slice(start, start + pixelsPerImage)
      }
      println("itemsBuffer=" + itemsArray.length)
      val trainingData = labels.zip(itemsArray)
      // Print how many images each digit has.
      trainingData.groupBy(a => a._1).mapValues(b => b.size).foreach(println)

      // Keep only the 0 and 1 images. MNIST pixels are unsigned bytes (0-255) but Scala's
      // Byte is signed, so mask with 0xFF before widening; otherwise 128-255 become negative.
      val rawData = trainingData
        .filter { case (label, _) => label == 0 || label == 1 }
        .map { case (label, pixels) =>
          LabeledPoint(label.toDouble, Vectors.dense(pixels.map(p => (p & 0xFF).toDouble)))
        }

      val data = sc.parallelize(rawData, 4)
      val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
      val training = splits(0).cache()
      val test = splits(1)

      // Train the model.
      val model = new LogisticRegressionWithLBFGS().setNumClasses(2).run(training)

      // A ROC curve needs continuous scores, not hard 0/1 predictions. Remember the decision
      // threshold, then clear it so predict() returns the class-1 probability.
      val threshold = model.getThreshold.getOrElse(0.5)
      model.clearThreshold()
      val scoreAndLabels = test.map {
        case LabeledPoint(label, features) => (model.predict(features), label)
      }.cache()

      // Print the metrics. (The values once recorded here, e.g. AUC = 0.7238888888888889, came
      // from signed pixel values and hard predictions; re-measure.)
      val metrics = new BinaryClassificationMetrics(scoreAndLabels)
      val auROC = metrics.areaUnderROC()
      println("Area under ROC = " + auROC)

      // Re-apply the original threshold to obtain hard predictions for precision/recall.
      val predictionAndLabels = scoreAndLabels.map {
        case (score, label) => (if (score > threshold) 1.0 else 0.0, label)
      }
      val metrics2 = new MulticlassMetrics(predictionAndLabels)
      // NOTE(review): no-arg precision/recall are deprecated in newer Spark (use accuracy there).
      println("Precision = " + metrics2.precision + ",Recall = " + metrics2.recall)
    } finally {
      sc.stop()
    }
  }

  /**
   * Applies `f` to `resource` and always closes the resource afterwards, even if `f` throws.
   *
   * @param resource anything with a `close(): Unit` method
   * @param f        work to perform with the open resource
   * @return the result of `f`
   */
  def using[A <: { def close(): Unit }, B](resource: A)(f: A => B): B =
    try f(resource) finally resource.close()

  /**
   * Loads an MNIST IDX1 label file and returns its raw bytes. The file is an 8-byte header
   * followed by one label byte per item.
   *
   * Prints the header fields and the first three labels as a sanity check.
   *
   * @param path IDX1 label file to read; defaults to the t10k (test) labels
   * @return the whole file, header included
   * @throws IllegalArgumentException if the magic number is not 2049 (0x00000801)
   */
  def loadMNISTTrainLabelDataSet(path: String = DefaultLabelPath): Array[Byte] = {
    val labelDS = readAllBytes(path)

    // 32-bit integer 0x00000801 (2049) magic number, MSB first (big-endian)
    val magicLabelNum = readInt(labelDS, 0)
    println(s"magicLabelNum=$magicLabelNum")
    require(magicLabelNum == LabelMagic, s"$path is not an IDX1 label file (magic=$magicLabelNum)")
    // 32-bit integer: number of items (60000 for the train set, 10000 for t10k)
    val numOfLabelItems = readInt(labelDS, 4)
    println(s"numOfLabelItems=$numOfLabelItems")
    // Print the first three labels.
    for ((e, index) <- labelDS.drop(LabelHeaderSize).take(3).zipWithIndex) {
      println(s"$index is $e")
    }

    labelDS
  }

  /**
   * Loads an MNIST IDX3 image file and returns its raw bytes. The file is a 16-byte header
   * followed by rows * cols unsigned pixel bytes per image.
   *
   * Prints the header fields as a sanity check.
   *
   * @param path IDX3 image file to read; defaults to the t10k (test) images
   * @return the whole file, header included
   * @throws IllegalArgumentException if the magic number is not 2051 (0x00000803) or the pixel
   *                                  payload does not match items * rows * cols
   */
  def loadMNISTTrainDataSet(path: String = DefaultImagePath): Array[Byte] = {
    val trainingDS = readAllBytes(path)

    // 32-bit integer 0x00000803 (2051) magic number
    val magicNum = readInt(trainingDS, 0)
    println(s"magicNum=$magicNum")
    require(magicNum == ImageMagic, s"$path is not an IDX3 image file (magic=$magicNum)")
    // 32-bit integer: number of items (60000 for the train set, 10000 for t10k)
    val numOfItems = readInt(trainingDS, 4)
    println(s"numOfItems=$numOfItems")
    // 32-bit integer: number of rows (28)
    val numOfRows = readInt(trainingDS, 8)
    println(s"numOfRows=$numOfRows")
    // 32-bit integer: number of columns (28)
    val numOfCols = readInt(trainingDS, 12)
    println(s"numOfCols=$numOfCols")

    // The payload holds one byte per pixel for every image: items * rows * cols.
    val payloadBytes = trainingDS.length - ImageHeaderSize
    val expectedBytes = numOfItems.toLong * numOfRows * numOfCols
    println(s"pixelBytes=$payloadBytes=$expectedBytes")
    require(payloadBytes == expectedBytes,
      s"$path has $payloadBytes pixel bytes, expected $expectedBytes")

    trainingDS
  }

  /**
   * Reads a whole file into memory. `readFully` loops until the buffer is full; a single
   * `InputStream.read` call may legally return fewer bytes.
   */
  private def readAllBytes(path: String): Array[Byte] = {
    val file = new File(path)
    val bytes = new Array[Byte](file.length.toInt)
    using(new DataInputStream(new FileInputStream(file))) { in => in.readFully(bytes) }
    bytes
  }

  /** Reads the big-endian 32-bit integer that starts at `offset` in `bytes`. */
  private def readInt(bytes: Array[Byte], offset: Int): Int =
    ByteBuffer.wrap(bytes, offset, 4).getInt

}



评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值