:ROC曲线概念
http://blog.csdn.net/abcjennifer/article/details/7359370
:Recall-Precision概念
http://blog.csdn.net/pirage/article/details/9851339
:下载MNIST数据集
http://yann.lecun.com/exdb/mnist/
:Logistic Regression:从入门到精通
http://www.tianyancha.com/research/LR_intro.pdf
:加载MNIST数据类
package com.bbw5.ml.spark.data
import java.io.File
import java.io.FileInputStream
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.Vector
/**
 * Loader for the MNIST handwritten-digit dataset stored in the IDX binary
 * format described at http://yann.lecun.com/exdb/mnist/
 *
 * @param dataDir  directory prefix that the relative IDX file names below are appended to
 * @param numCount only samples whose label is strictly smaller than this value are
 *                 returned by [[loadData]] (the default of 2 keeps the digits 0 and 1)
 */
class MNISTData(val dataDir: String, val numCount: Int = 2) {
  // `using` closes its resource through a structural type, which is a reflective call.
  import scala.language.reflectiveCalls

  val trainLabeFileName = "/train-labels-idx1-ubyte/train-labels.idx1-ubyte"
  val trainImageFileName = "/train-images-idx3-ubyte/train-images.idx3-ubyte"
  val testLabeFileName = "/t10k-labels-idx1-ubyte/t10k-labels.idx1-ubyte"
  val testImageFileName = "/t10k-images-idx3-ubyte/t10k-images.idx3-ubyte"

  /** Header length of a label file: magic number + item count (two 32-bit integers). */
  private val LabelHeaderSize = 8

  /** Header length of an image file: magic number + item count + rows + columns. */
  private val ImageHeaderSize = 16

  /**
   * Runs `f` on `resource` and always closes the resource afterwards,
   * whether `f` returns normally or throws.
   *
   * @param resource anything exposing a `close()` method
   * @param f        computation that consumes the resource
   * @return the value produced by `f`
   */
  def using[A <: { def close(): Unit }, B](resource: A)(f: A => B): B =
    try {
      f(resource)
    } finally {
      resource.close()
    }

  /**
   * Reads the whole file `dataDir + filename` into memory.
   *
   * The stream that is read is the one managed by `using`, so it is always
   * closed. `InputStream.read` may return fewer bytes than requested, hence
   * the loop; it ends when the buffer is full or end-of-file is reached.
   */
  private def readFileBytes(filename: String): Array[Byte] = {
    val file = new File(dataDir + filename)
    val bytes = new Array[Byte](file.length.toInt)
    using(new FileInputStream(file)) { source =>
      var offset = 0
      var count = 0
      while (offset < bytes.length && count >= 0) {
        count = source.read(bytes, offset, bytes.length - offset)
        if (count > 0) offset += count
      }
    }
    bytes
  }

  /** Decodes the big-endian 32-bit integer stored at `bytes(offset until offset + 4)`. */
  private def readInt(bytes: Array[Byte], offset: Int): Int =
    ByteBuffer.wrap(bytes.slice(offset, offset + 4)).getInt

  /** Loads the raw bytes (header included) of the training label file. */
  def loadTrainLabelData(): Array[Byte] = {
    loadLabelData(trainLabeFileName)
  }

  /** Loads the raw bytes (header included) of the test label file. */
  def loadTestLabelData(): Array[Byte] = {
    loadLabelData(testLabeFileName)
  }

  /**
   * Loads an MNIST label file and prints its header fields plus the first
   * three labels as a sanity check.
   *
   * @param filename path of the label file, relative to `dataDir`
   * @return the complete file content, including the 8-byte header
   */
  def loadLabelData(filename: String): Array[Byte] = {
    val labelDS = readFileBytes(filename)
    // 32 bit integer 0x00000801(2049) magic number (MSB first--high endian)
    val magicLabelNum = readInt(labelDS, 0)
    println(s"magicLabelNum=$magicLabelNum")
    // 32 bit integer 60000 number of items
    val numOfLabelItems = readInt(labelDS, 4)
    println(s"numOfLabelItems=$numOfLabelItems")
    // Print a few sample labels.
    for ((e, index) <- labelDS.drop(LabelHeaderSize).take(3).zipWithIndex) {
      println(s"image$index is $e")
    }
    labelDS
  }

  /** Loads the raw bytes (header included) of the training image file. */
  def loadTrainImageData(): Array[Byte] = {
    loadImageData(trainImageFileName)
  }

  /** Loads the raw bytes (header included) of the test image file. */
  def loadTestImageData(): Array[Byte] = {
    loadImageData(testImageFileName)
  }

  /**
   * Loads an MNIST image file and prints its header fields together with the
   * actual and the expected payload size.
   *
   * @param filename path of the image file, relative to `dataDir`
   * @return the complete file content, including the 16-byte header
   */
  def loadImageData(filename: String): Array[Byte] = {
    val trainingDS = readFileBytes(filename)
    // 32 bit integer 0x00000803(2051) magic number
    val magicNum = readInt(trainingDS, 0)
    println(s"magicNum=$magicNum")
    // 32 bit integer 60000 number of items
    val numOfItems = readInt(trainingDS, 4)
    println(s"numOfItems=$numOfItems")
    // 32 bit integer 28 number of rows
    val numOfRows = readInt(trainingDS, 8)
    println(s"numOfRows=$numOfRows")
    // 32 bit integer 28 number of columns
    val numOfCols = readInt(trainingDS, 12)
    println(s"numOfCols=$numOfCols")
    // Payload bytes actually present vs. items * rows * columns announced by the header.
    println(s"numOfItems=" + (trainingDS.length - ImageHeaderSize) + "=" + (numOfItems * numOfRows * numOfCols))
    trainingDS
  }

  /** Loads the training set as (label, pixel vector) pairs, filtered by `numCount`. */
  def loadTrainData(): Array[(Double, Vector)] = {
    loadData(() => loadTrainLabelData(), () => loadTrainImageData())
  }

  /** Loads the test set as (label, pixel vector) pairs, filtered by `numCount`. */
  def loadTestData(): Array[(Double, Vector)] = {
    loadData(() => loadTestLabelData(), () => loadTestImageData())
  }

  /**
   * Pairs every label with its image and converts the samples whose label is
   * below `numCount` into (label, dense feature vector) tuples.
   *
   * @param loadLabelFunc supplies the raw label file content (header included)
   * @param loadImageFunc supplies the raw image file content (header included)
   * @return one tuple per retained sample; the vector holds one value per pixel
   */
  def loadData(loadLabelFunc: () => Array[Byte], loadImageFunc: () => Array[Byte]): Array[(Double, Vector)] = {
    val labels = loadLabelFunc().drop(LabelHeaderSize)
    val trainingDS = loadImageFunc()
    val numOfItems = readInt(trainingDS, 4)
    // Image size comes from the header (rows * columns) instead of a fixed 28 * 28.
    val imageSize = readInt(trainingDS, 8) * readInt(trainingDS, 12)
    val itemsArray = Array.tabulate(numOfItems) { i =>
      val start = ImageHeaderSize + i * imageSize
      trainingDS.slice(start, start + imageSize)
    }
    println("numOfImages=" + itemsArray.length)
    val data = labels.zip(itemsArray)
    // Print how many samples exist per digit.
    println("image digit count:")
    data.groupBy(_._1).mapValues(_.length).foreach(println)
    // Keep only the labels below numCount (0/1 by default). Pixels are unsigned
    // bytes, so mask with 0xFF: a plain Byte.toDouble would turn 128..255 negative.
    data
      .filter { case (label, _) => label < numCount }
      .map { case (label, pixels) =>
        (label.toDouble, Vectors.dense(pixels.map(b => (b & 0xFF).toDouble)))
      }
  }
}
/** Convenience entry points that build a one-off [[MNISTData]] loader. */
object MNISTData {

  /**
   * Loads the MNIST training samples found under `dataDir`.
   *
   * @param dataDir  directory prefix passed to the loader
   * @param numCount label upper bound (exclusive) passed to the loader
   */
  def loadTrainData(dataDir: String, numCount: Int = 2): Array[(Double, Vector)] = {
    val loader = new MNISTData(dataDir, numCount)
    loader.loadTrainData()
  }

  /**
   * Loads the MNIST test samples found under `dataDir`.
   *
   * @param dataDir  directory prefix passed to the loader
   * @param numCount label upper bound (exclusive) passed to the loader
   */
  def loadTestData(dataDir: String, numCount: Int = 2): Array[(Double, Vector)] = {
    val loader = new MNISTData(dataDir, numCount)
    loader.loadTestData()
  }
}
:使用ML API进行分类
package com.bbw5.ml.spark
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import com.bbw5.ml.spark.data.MNISTData
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.ParamGridBuilder
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.tuning.TrainValidationSplit
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary
import org.apache.spark.ml.classification.LogisticRegressionModel
import java.util.Date
/**
 * Classifies the digits 0 and 1 of the MNIST handwritten-digit dataset with
 * Spark ML `LogisticRegression`, selecting hyper-parameters through a
 * train/validation split.
 *
 * @author baibaiw5
 */
object LogisticRegression4MNIST {

  /** Creates the Spark context, runs [[tvSplit]] and always stops the context afterwards. */
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("LogisticRegression4MNIST")
    val sc = new SparkContext(sparkConf)
    try {
      val sqlContext = new org.apache.spark.sql.SQLContext(sc)
      tvSplit(sc, sqlContext)
    } finally {
      // Release the context even when training or evaluation fails.
      sc.stop()
    }
  }

  /**
   * Saves the model under a timestamped directory and prints its parameters,
   * the per-iteration objective values, the ROC and precision-recall tables
   * and the area under the ROC curve taken from the model's training summary.
   *
   * @param bestModel fitted logistic regression model to report on
   */
  def printResult(bestModel: LogisticRegressionModel): Unit = {
    bestModel.save("D:/Develop/Model/MNIST-LR-" + System.currentTimeMillis())
    println("bestModel.params:" + bestModel.extractParamMap)
    val trainingSummary = bestModel.summary
    // Obtain the objective per iteration.
    val objectiveHistory = trainingSummary.objectiveHistory
    println("print lost in every step:")
    objectiveHistory.foreach(loss => println(loss))
    val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]
    // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
    println("print roc in every step:")
    val roc = binarySummary.roc
    roc.show(roc.count.toInt)
    println("print recall,precision every step:")
    val pr = binarySummary.pr
    pr.show(pr.count.toInt)
    // Value observed by the author on one run: 0.9409024010447935
    println("areaUnderROC=" + binarySummary.areaUnderROC)
  }

  /**
   * Trains on the MNIST training set with a parameter grid evaluated by
   * `TrainValidationSplit`, prints predictions and a label/prediction count
   * table for the MNIST test set, then reports on the best model.
   *
   * @param sc         Spark context used to parallelize the loaded samples
   * @param sqlContext SQL context providing the DataFrame conversions
   * @param dataDir    directory holding the MNIST files; defaults to the
   *                   location that was previously hard-coded
   */
  def tvSplit(sc: SparkContext, sqlContext: SQLContext, dataDir: String = "I:/DM-dataset/MNIST/"): Unit = {
    import sqlContext.implicits._
    val training = sc.parallelize(MNISTData.loadTrainData(dataDir), 4).toDF("label", "features").cache()
    training.describe("label").show
    val test = sc.parallelize(MNISTData.loadTestData(dataDir), 4).toDF("label", "features").cache()
    test.describe("label").show
    val lr = new LogisticRegression()
    // Candidate settings: three regularization strengths, the fitIntercept grid,
    // a single iteration limit and three elastic-net mixing values.
    val paramGrid = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.0001, 0.01, 1.0))
      .addGrid(lr.fitIntercept)
      .addGrid(lr.maxIter, Array(100))
      .addGrid(lr.elasticNetParam, Array(0.1, 0.5, 1.0))
      .build()
    // 80% of the data will be used for training and the remaining 20% for validation.
    val trainValidationSplit = new TrainValidationSplit()
      .setEstimator(lr)
      .setEvaluator(new BinaryClassificationEvaluator)
      .setEstimatorParamMaps(paramGrid)
      .setTrainRatio(0.8)
    // Run train validation split, and choose the best set of parameters.
    val model = trainValidationSplit.fit(training)
    // Make predictions on test data. model is the model with combination of parameters
    // that performed best.
    val testDF = model.transform(test)
    testDF.select("label", "prediction").show()
    testDF.groupBy("label", "prediction").count().show()
    printResult(model.bestModel.asInstanceOf[LogisticRegressionModel])
  }
}
:在测试集上的预测结果如下,准确率约为97.8%(2068/2115);AUC 0.9409024010447935 由 bestModel.summary 给出,即训练集上的 areaUnderROC
+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
| 1.0| 1.0| 1106|
| 0.0| 0.0| 962|
| 0.0| 1.0| 18|
| 1.0| 0.0| 29|
+-----+----------+-----+