package main
import java.io.{File, PrintWriter}
import java.text.SimpleDateFormat
import java.util.{Calendar, Date}
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.{SparkConf, SparkContext}
import utils.RedisUtil
import scala.collection.mutable.ArrayBuffer
object Train {
def main(args: Array[String]): Unit = {
// 1 将本次评估结果保存到文件
// import java.io.PrintWriter
// import java.io.File
val writer = new PrintWriter(new File("model_training.txt"))
// 2 配置spark
// import org.apache.spark.SparkConf
val sparkConf = new SparkConf()
.setMaster("local[2]")
.setAppName("TrafficTrain")
// 3 实例化SparkContext
val sc = new SparkContext(sparkConf)
// 4 redis
val dbIndex = 1
// 5 从redis中获取数据
// import utils.RedisUtil
val jedis = RedisUtil.pool.getResource
jedis.select(dbIndex)
// 6 设置建模对象(随便选的)
val monitorIDs = List("0005", "0015")
// 7 对上面两个监测点进行建模,但是这两个监测点可能需要其他监测点的数据信息
val monitorRelations = Map[String, Array[String]](
"0005" -> Array("0003", "0004", "0005", "0006", "0007"),
"0015" -> Array("0013", "0014", "0015", "0016", "0017"))
// 8 遍历所有监测点,进行数据建模
// 本例应该是:0005,0015
monitorIDs.map(monitorID => {
// 得到Array(相关监测点)
val monitorRelationArray = monitorRelations(monitorID)
// 处理时间,初始化
val currentDate = Calendar.getInstance().getTime
// 设置时间格式化
// 当前时间
val hourMinuteSDF = new SimpleDateFormat("HHmm")
// 当前年月日
val dateSDF = new SimpleDateFormat("yyyyMMdd")
val dateOfString = dateSDF.format(currentDate)
// 根据相关监测点,取得当日的所有的监测点的平均车速
val relationInfo = monitorRelationArray.map(monitorID => {
(monitorID, jedis.hgetAll(dateOfString + "_" + monitorID))
})
// 使用n个小时内的数据进行建模
val hours = 1
// 有监督学习
// 创建3个数组,一个数组用于存放特征向量,一数组用于存放Label向量,一个数组用于存放前两者之间的关联
// 用于存放特征向量和特征结果的映射关系
// import scala.collection.mutable.ArrayBuffer
// import org.apache.spark.mllib.regression.LabeledPoint
val dataTrain = ArrayBuffer[LabeledPoint]()
val dataX = ArrayBuffer[Double]()
val dataY = ArrayBuffer[Double]()
// 将时间拉回到1小时之前,倒序,拉回单位:分钟
for (i <- Range(60 * hours, 2, -1)) {
dataX.clear()
dataY.clear()
// 以下内容包含:线性滤波
for (index <- 0 to 2) {
// 当前毫秒数 - 1个小时之前的毫秒数+1个小时之前的后0分钟,1分钟,2分钟的毫秒数(第3分钟作为Label向量)
val oneMoment = currentDate.getTime - 60 * i * 1000 + 60 * index * 1000
// 拼装除当前(当前for循环这一次的时间)的小时分钟数
// import java.util.Date
val oneHM = hourMinuteSDF.format(new Date(oneMoment))
// 取得该时刻下里面数据
// 取出的数据形式距离:(0005,{1033=93_2,1034=1356_30})
for ((k, v) <- relationInfo) {
// 如果index==2,意味着前三分钟的数据已经组装到了dataX中,那么下一时刻的数据,如果是目标卡口,则需要存放于dataY中
if (k == monitorID && index == 2) {
// 第四分钟数据
val nextMoment = oneMoment + 60 * 1000
val nextHM = hourMinuteSDF.format(new Date(nextMoment))
// 判断是否有数据
if (v.containsKey(nextHM)) {
val speedAndCarCount = v.get(nextHM).split("_")
// 得到第4分钟的平均车速
val valueY = speedAndCarCount(0).toFloat / speedAndCarCount(1).toFloat
dataY += valueY
}
}
// 组装前3分钟的dataX
if (v.containsKey(oneHM)) {
val speedAndCarCount = v.get(oneHM).split("_")
// 得到当前这一分钟的特征值
val valueX = speedAndCarCount(0).toFloat / speedAndCarCount(1).toFloat
dataX += valueX
} else {
dataX += 60.0F
}
}
}
// 准备训练模型
// 先将dataX和dataY映射于一个LabeledPoint对象中
if (dataY.toArray.length == 1) {
// 答案的平均车速
val label = dataY.toArray.head
val record = LabeledPoint(
if (label / 10 < 6) (label / 10).toInt
else 6,
// import org.apache.spark.mllib.linalg.Vectors
Vectors.dense(dataX.toArray))
dataTrain += record
}
}
// 将数据集写入到文件中方便查看
dataTrain.foreach(record => {
println(record)
writer.write(record.toString() + "\n")
})
// 开始组装训练集和测试集
val rddData = sc.parallelize(dataTrain)
// 切分数据集
val randomSplits = rddData.randomSplit(Array(0.6, 0.4), 11L)
// 训练集
val trainingData = randomSplits(0)
// 测试集
val testData = randomSplits(1)
// 使用训练集进行建模
val model = new LogisticRegressionWithLBFGS().setNumClasses(7).run(trainingData)
// 完成建模之后,使用测试集,评估模型精确度
val predictionAndLabels = testData.map {
case LabeledPoint(label, features) =>
val prediction = model.predict(features)
(prediction, label)
}
// 得到当前检测点model的评估值
val metries = new MulticlassMetrics(predictionAndLabels)
// accuracy取值范围0.0-0.1
val accuracy = metries.accuracy
println("评估值:" + accuracy)
writer.write(accuracy.toString + "\r\n")
// 设置评估阈值:超过多少精确度,则保存模型
if (accuracy > 0.0) {
// 将模型保存到HDFS,注意用户权限
val hdfsPath =
"hdfs://centosa:9000/traffic/model/" + monitorID + "_" +
new SimpleDateFormat("yyyyMMddHHmmss").format(new Date(System.currentTimeMillis()))
model.save(sc, hdfsPath)
jedis.hset("model", monitorID, hdfsPath)
}
})
RedisUtil.pool.returnResource(jedis)
writer.flush
writer.close()
}
}
run Train.scala
(需要启动producer和SparkConsumer)
控制台输出评估值
得到特征向量:
模型:
(省略大部分)
HDFS上内容: