:下载数据集
http://yann.lecun.com/exdb/mnist/
:使用python查看原始数据集的图片
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
# Display the first image stored in an MNIST IDX3 image file.
# File layout: a 16-byte header (magic number, image count, row count, column
# count -- four big-endian 32-bit integers), then one unsigned byte per pixel.
MNIST_IMAGE_HEADER_BYTES = 16
ROWS, COLS = 28, 28
with open('D:/DM-datasets/MNIST/t10k-images-idx3-ubyte/t10k-images.idx3-ubyte', 'rb') as mnistData:
    # Skip the header; the pixels of the first image follow immediately.
    mnistData.seek(MNIST_IMAGE_HEADER_BYTES)
    pixels = mnistData.read(ROWS * COLS)
if len(pixels) != ROWS * COLS:
    raise ValueError('MNIST image file is truncated: expected %d pixel bytes, got %d'
                     % (ROWS * COLS, len(pixels)))
# bytearray yields ints in 0..255 on both Python 2 and Python 3.
image1 = list(bytearray(pixels))
print(image1)
arr = np.array(image1)
plt.imshow(arr.reshape((ROWS, COLS)), cmap=cm.Greys_r)
plt.axis('off')
plt.show()
:使用spark做logistic regression
package com.bbw5.ml.spark
import java.io.ByteArrayInputStream
import java.io.File
import java.io.FileInputStream
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
import javax.imageio.ImageIO
import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.{ LogisticRegressionWithLBFGS, LogisticRegressionModel }
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.SparkConf
import org.apache.spark.mllib.evaluation.MulticlassMetrics
/**
 * Uses logistic regression to classify the images of the digits 0 and 1 from the
 * MNIST handwritten-digit dataset.
 *
 * @author baibaiw5
 */
object LogisticRegression4MNIST {

  /** Default label file. NOTE(review): despite the loader's "Train" name, this is the t10k file. */
  val DefaultLabelPath = "D:/DM-datasets/MNIST/t10k-labels-idx1-ubyte/t10k-labels.idx1-ubyte"

  /** Default image file. NOTE(review): despite the loader's "Train" name, this is the t10k file. */
  val DefaultImagePath = "D:/DM-datasets/MNIST/t10k-images-idx3-ubyte/t10k-images.idx3-ubyte"

  /** Magic number at the start of an IDX1 label file: 0x00000801 (2049). */
  private val LabelMagic = 0x00000801

  /** Magic number at the start of an IDX3 image file: 0x00000803 (2051). */
  private val ImageMagic = 0x00000803

  /** Label file header: magic number and item count, each a 32-bit integer. */
  private val LabelHeaderBytes = 8

  /** Image file header: magic number, item count, row count and column count, each a 32-bit integer. */
  private val ImageHeaderBytes = 16

  /**
   * Loads the label and image files, keeps only the digits 0 and 1, trains a binary
   * logistic-regression model on a 60% random split and prints evaluation metrics on
   * the remaining 40%.
   */
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("LogisticRegression4MNIST")
    val sc = new SparkContext(sparkConf)
    try {
      val labelDS: Array[Byte] = loadMNISTTrainLabelDataSet()
      val labels = labelDS.drop(LabelHeaderBytes)
      val trainingDS: Array[Byte] = loadMNISTTrainDataSet()
      val numOfItems = readInt(trainingDS, 4)
      // Image size comes from the header (rows * columns) instead of being hard-coded.
      val pixelsPerImage = readInt(trainingDS, 8) * readInt(trainingDS, 12)
      require(pixelsPerImage > 0, s"invalid image size in header: $pixelsPerImage pixels")
      // zip below would silently drop unmatched entries, so insist that the two files line up.
      require(labels.length == numOfItems,
        s"label file has ${labels.length} labels but image file has $numOfItems images")
      // Cut the pixel area after the header into one array per image.
      val itemsArray = trainingDS.drop(ImageHeaderBytes).grouped(pixelsPerImage).take(numOfItems).toArray
      println("itemsBuffer=" + itemsArray.length)
      val trainingData = labels.zip(itemsArray)
      // Print how many samples there are per digit.
      trainingData.groupBy(a => a._1).mapValues(b => b.size).foreach(println)
      // Only the images of the digits 0 and 1 are classified.
      val rawData = trainingData.filter(p => p._1 == 0 || p._1 == 1).map { case (label, pixels) =>
        // Pixels are unsigned bytes (0-255) but Scala's Byte is signed; without the
        // mask, pixel values 128-255 would become negative features.
        new LabeledPoint(label.toDouble, Vectors.dense(pixels.map(c => (c & 0xFF).toDouble)))
      }
      val data = sc.parallelize(rawData, 4)
      val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
      val training = splits(0).cache()
      val test = splits(1)
      // Train the model.
      val model = new LogisticRegressionWithLBFGS().setNumClasses(2).run(training)
      // Hard 0/1 predictions (default threshold) on the test set, used for precision/recall.
      val predictionAndLabels = test.map {
        case LabeledPoint(label, features) =>
          val prediction = model.predict(features)
          (prediction, label)
      }
      // ROC needs continuous scores, so use a threshold-free copy of the model, which
      // returns raw probabilities. Clearing the threshold on `model` itself would also
      // change predictionAndLabels, because RDDs are evaluated lazily.
      val scoringModel =
        new LogisticRegressionModel(model.weights, model.intercept, model.numFeatures, model.numClasses)
      scoringModel.clearThreshold()
      val scoreAndLabels = test.map {
        case LabeledPoint(label, features) => (scoringModel.predict(features), label)
      }
      // Print the metrics.
      // NOTE(review): the previously recorded results (AUC 0.7239, precision/recall 0.7258)
      // were produced with signed-byte features and AUC on hard predictions; re-record them.
      val metrics = new BinaryClassificationMetrics(scoreAndLabels)
      val auROC = metrics.areaUnderROC()
      println("Area under ROC = " + auROC)
      val metrics2 = new MulticlassMetrics(predictionAndLabels)
      println("Precision = " + metrics2.precision + ",Recall = " + metrics2.recall)
    } finally {
      sc.stop()
    }
  }

  /**
   * Loan pattern: applies `f` to `resource` and closes the resource afterwards,
   * even when `f` throws.
   */
  def using[A <: { def close(): Unit }, B](resource: A)(f: A => B): B =
    try {
      f(resource)
    } finally {
      resource.close()
    }

  /**
   * Loads an MNIST label file (IDX1 format) into memory and prints its header fields.
   *
   * @param path label file to read; defaults to [[DefaultLabelPath]]
   * @return the whole file: the 8-byte header followed by one label byte per item
   * @throws java.io.IOException if the file cannot be read
   * @throws IllegalArgumentException if the file is shorter than the header or has the wrong magic number
   */
  def loadMNISTTrainLabelDataSet(path: String = DefaultLabelPath): Array[Byte] = {
    val labelDS = readFile(path)
    require(labelDS.length >= LabelHeaderBytes, s"$path is too short to be an MNIST label file")
    // 32-bit big-endian integer 0x00000801 (2049): magic number
    val magicLabelNum = readInt(labelDS, 0)
    println(s"magicLabelNum=$magicLabelNum")
    require(magicLabelNum == LabelMagic, s"$path is not an MNIST label file (magic number $magicLabelNum)")
    // 32-bit big-endian integer: number of items
    val numOfLabelItems = readInt(labelDS, 4)
    println(s"numOfLabelItems=$numOfLabelItems")
    // Print the first few labels as a sanity check.
    for ((e, index) <- labelDS.drop(LabelHeaderBytes).take(3).zipWithIndex) {
      println(s"$index is $e")
    }
    labelDS
  }

  /**
   * Loads an MNIST image file (IDX3 format) into memory and prints its header fields.
   *
   * @param path image file to read; defaults to [[DefaultImagePath]]
   * @return the whole file: the 16-byte header followed by one unsigned pixel byte per pixel
   * @throws java.io.IOException if the file cannot be read
   * @throws IllegalArgumentException if the file is shorter than the header or has the wrong magic number
   */
  def loadMNISTTrainDataSet(path: String = DefaultImagePath): Array[Byte] = {
    val trainingDS = readFile(path)
    require(trainingDS.length >= ImageHeaderBytes, s"$path is too short to be an MNIST image file")
    // 32-bit big-endian integer 0x00000803 (2051): magic number
    val magicNum = readInt(trainingDS, 0)
    println(s"magicNum=$magicNum")
    require(magicNum == ImageMagic, s"$path is not an MNIST image file (magic number $magicNum)")
    // 32-bit big-endian integer: number of images
    val numOfItems = readInt(trainingDS, 4)
    println(s"numOfItems=$numOfItems")
    // 32-bit big-endian integer: number of rows per image
    val numOfRows = readInt(trainingDS, 8)
    println(s"numOfRows=$numOfRows")
    // 32-bit big-endian integer: number of columns per image
    val numOfCols = readInt(trainingDS, 12)
    println(s"numOfCols=$numOfCols")
    // Compare the pixel bytes actually present with what the header promises.
    println(s"numOfPixelBytes=${trainingDS.length - ImageHeaderBytes}=${numOfItems * numOfRows * numOfCols}")
    trainingDS
  }

  /** Reads the whole file into memory; unlike a single InputStream.read, this never returns a partial buffer. */
  private def readFile(path: String): Array[Byte] =
    java.nio.file.Files.readAllBytes(new File(path).toPath)

  /** Reads the 32-bit integer at `offset`; ByteBuffer's default order is big-endian, as MNIST requires. */
  private def readInt(bytes: Array[Byte], offset: Int): Int =
    ByteBuffer.wrap(bytes, offset, 4).getInt
}