最近研究了一下时间序列预测的使用,网上找了大部分的资源,都是使用python来实现的,使用python来实现虽然能满足大部分的需求,但是python有一点缺点按就是只能使用一台计算资源进行计算,如果数据量大的时候,就有可能不能胜任,虽然这种情况很少,但是还是有可能会发生,因此就查了一下spark有没有这方面的资料,没想到还真的有,使用spark集群进行计算速度方面提升明显。
首先非常感谢这位博主,我是在学习了他的代码之下才能更好的理解spark-timeseries的使用。
下面是我对代码的改进,主要是调整的是时间类型的通用性与arima模型能自定义pdq参数等,能通用大部分类型的时间。
TimeFormatUtils.java
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.HashMap;
import java.util.regex.Pattern;
public class TimeFormatUtils {
/**
* 获取时间类型格式
*
* @param timeStr
* @return
*/
public static String getDateType(String timeStr) {
HashMap dateRegFormat = new HashMap();
dateRegFormat.put("^\\d{4}\\D+\\d{1,2}\\D+\\d{1,2}\\D+\\d{1,2}\\D+\\d{1,2}\\D+\\d{1,2}\\D*$", "yyyy-MM-dd HH:mm:ss");//2014年3月12日 13时5分34秒,2014-03-12 12:05:34,2014/3/12 12:5:34
dateRegFormat.put("^\\d{4}\\D+\\d{2}\\D+\\d{2}\\D+\\d{2}\\D+\\d{2}$", "yyyy-MM-dd HH:mm");//2014-03-12 12:05
dateRegFormat.put("^\\d{4}\\D+\\d{2}\\D+\\d{2}\\D+\\d{2}$", "yyyy-MM-dd HH");//2014-03-12 12
dateRegFormat.put("^\\d{4}\\D+\\d{2}\\D+\\d{2}$", "yyyy-MM-dd");//2014-03-12
dateRegFormat.put("^\\d{4}\\D+\\d{2}$", "yyyy-MM");//2014-03
dateRegFormat.put("^\\d{4}$", "yyyy");//2014
dateRegFormat.put("^\\d{14}$", "yyyyMMddHHmmss");//20140312120534
dateRegFormat.put("^\\d{12}$", "yyyyMMddHHmm");//201403121205
dateRegFormat.put("^\\d{10}$", "yyyyMMddHH");//2014031212
dateRegFormat.put("^\\d{8}$", "yyyyMMdd");//20140312
dateRegFormat.put("^\\d{6}$", "yyyyMM");//201403
try {
for (String key : dateRegFormat.keySet()) {
if (Pattern.compile(key).matcher(timeStr).matches()) {
String formater = "";
if (timeStr.contains("/"))
return dateRegFormat.get(key).replaceAll("-", "/");
else
return dateRegFormat.get(key);
}
}
} catch (Exception e) {
System.err.println("-----------------日期格式无效:" + timeStr);
e.printStackTrace();
}
return null;
}
public static String fromatData(String time, SimpleDateFormat format) {
try {
SimpleDateFormat formatter = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
return formatter.format(format.parse(time));
} catch (ParseException e) {
e.printStackTrace();
}
return null;
}
}
TimeSeriesTrain.scala
import java.sql.Timestamp
import java.text.SimpleDateFormat
import java.time.{ZoneId, ZonedDateTime}
import com.cloudera.sparkts._
import com.sendi.TimeSeries.Util.TimeFormatUtils
import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
/**
* 时间序列模型time-series的建立
*/
object TimeSeriesTrain {
/**
* 总方法调用
*/
def timeSeries(args: Array[String]) {
args.foreach(println)
Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
/**
* 1、初始化spark环境
*/
val sparkSession = SparkSession.builder
.master("local[4]").appName("SparkTest")
.enableHiveSupport() //创建支持HiveContext;
.getOrCreate()
/**
* 2、初始化参数
*/
//hive中的数据库名字
val databaseTableName = args(0)
//输入的列名必须是time data
val hiveColumnName = List(args(1).toString.split(","): _*)
//开始与结束时间
val startTime = args(2)
val endTime = args(3)
//获取时间类型
val sdf = new SimpleDateFormat(TimeFormatUtils.getDateType(startTime))
//时间跨度
val timeSpanType = args(4)
val timeSpan = args(5).toInt
//预测后面N个值
val predictedN = args(6).toInt
//存放的表名字
val outputTableName = args(7)
var listPDQ: List[String] = List("")
var period = 0
var holtWintersModelType = ""
//选择模型(holtwinters或者是arima)
val modelName = args(8)
//根据不同的类型赋值不同的参数
if (modelName.equals("arima")) {
listPDQ = List(args(9).toString.split(","): _*)
} else {
//季节性参数(12或者4)
period = args(9).toInt
//holtWinters选择模型:additive(加法模型)、Multiplicative(乘法模型)
holtWintersModelType = args(10)
}
/**
* 3、 读取数据源,最终转换成 {time key data} 这种类型的RDD格式
*/
val timeDataKeyDf = readHiveData(sparkSession, databaseTableName, hiveColumnName)
val zonedDateDataDf = timeChangeToDate(sparkSession, timeDataKeyDf, hiveColumnName, startTime, sdf)
/**
* 4、创建数据中时间的跨度(Create an daily DateTimeIndex):开始日期+结束日期+递增数
* 日期的格式要与数据库中time数据的格式一样