本次教程介绍的是,利用python调用scikit-learn库的神经网络模型,进行时间序列预测。
不同于传统的机器学习模型,不需要特征,只需要连续时间内的target,就可以预测未来时间内的target
这个问题被成为时间序列预测问题,传统的方法是利用ARIMA或者SPSS。但是我觉得ARIMA对开发者要求比较高,经常出现预测效果不好的问题。
SPSS不适合进行批量预测,这个方法对开发者要求不高,而且预测效果也还可以。
这里的背景是预测2000个shop未来6周的销售量。训练数据是2015-7-1至2016-10-30的流量(天池IJICAI)
数据下载地址 https://pan.baidu.com/s/1miz8CrA
github参考 https://github.com/wangtuntun/IJCAI_nnet
这里分几个部分进行:
1 利用spark统计每个shop每天的购买量(数据量太大,而且不会用python的datagrame的group操作)
2 对数据进行清洗,保证每个shop每天都有flow。(时间序列处理要求每天都得有)
3 构建神经网络模型
4 调用模型进行预测
首先第一步 利用spark统计每个shop每天的购买量(数据量太大,而且不会用python的datagrame的group操作)
/**
* Created by wangtuntun on 17-3-4.
* 实现的主要功能是计算出每个商家每天的流量:(shop_id,DS,flow)
*/
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.{SparkConf, SparkContext}
object clean {
def main(args: Array[String]) {
//设置环境
val conf=new SparkConf().setAppName("tianchi").setMaster("local")
val sc=new SparkContext(conf)
val sqc=new SQLContext(sc)
val user_pay_raw=sc.textFile("/home/wangtuntun/IJCAI/Data/user_pay.txt")
val user_pay_split=user_pay_raw.map(_.split(","))
val user_transform =user_pay_split.map{ x=> //数据转换
val userid=x(0)
val shop_id=x(1)
val ts=x(2)
val ts_split=ts.split(" ")
val year_month_day=ts_split(0).split("-")
val year=year_month_day(0)
val month=year_month_day(1)
val day=year_month_day(2)
// (shop_id,userid,year,month,day)
(shop_id,userid,ts_split(0))
}
val df=sqc.createDataFrame(user_transform) // 生成一个dataframe
val df_name_colums=df.toDF("shop_id","userid","DS") //给df的每个列取名字
df_name_colums.registerTempTable("user_pay_table") //注册临时表
val sql="select shop_id ,count(userid),DS from user_pay_table group by shop_id,DS order by shop_id desc,DS"
val rs =sqc.sql(sql)
rs.foreach(x=>println(x))
// user_transform.saveAsTextFile("/home/wangtuntun/test_file4.txt")
val rs_rdd=rs.map( x => x(0) + ","+ x(2).toString + "," + x(1) ) //rs转为rdd
rs_rdd.coalesce(1,true).saveAsTextFile("/home/wangtuntun/ds_flow_raw_data.txt")
sc.stop();
}
}
第二步 对数据进行清洗,保证每个shop每天都有flow。(时间序列处理要求每天都得有)
# encoding=utf-8
'''
原始的id_date_flow数据是由dataframe.sql(groupby user_id)完成,不能保证所有用户的所有天数据都有,所以进行一次清洗和填充
'''
from datetime import timedelta
import datetime
start_time_str = "2015-07-01"
end_time_str = "2016-10-30"
start_time = dat