Data source
Download link
This project uses a linear regression model for the prediction: it assumes that the number of bike-share rentals has a roughly linear relationship with the features in the data, so a linear model is fitted to make the prediction.
The code is as follows:
package com.huawei.sharebicyle
import org.apache.spark.SparkConf
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.{OneHotEncoderEstimator, VectorAssembler, VectorIndexer}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
import org.apache.spark.sql.{DataFrame, SparkSession}
object ShareBycile {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setMaster("local").setAppName("ShareBycycle")
val spark = SparkSession.builder().config(conf).getOrCreate()
spark.sparkContext.setLogLevel("WARN")
val rawdata = spark.read.format("csv").option("header",true)
.load("hour.csv")
rawdata.show(4)
rawdata.printSchema()
rawdata.describe("dteday","holiday","weekday","temp").show()
//Cast the raw string columns to Double and rename cnt to label
val data1 = rawdata.select(
rawdata("season").cast("Double"),
rawdata("yr").cast("Double"),
rawdata("mnth").cast("Double"),
rawdata("hr").cast("Double"),
rawdata("holiday").cast("Double"),
rawdata("weekday").cast("Double"),
rawdata("workingday").cast("Double"),
rawdata("weathersit").cast("Double"),
rawdata("temp").cast("Double"),
rawdata("atemp").cast("Double"),
rawdata("hum").cast("Double"),
rawdata("windspeed").cast("Double"),
rawdata("cnt").cast("Double").alias("label"))
data1.show(10)
println("-------------------------")
//Build an array of the predictor feature columns defined above
val featuresArray =
Array("season","yr","mnth","hr","holiday","weekday","workingday","weathersit",
"temp","atemp","hum","windspeed")
val assembler = new VectorAssembler().setInputCols(featuresArray)
.setOutputCol("features")
val featureIndexer: VectorIndexer = new VectorIndexer().setInputCol("features")
.setOutputCol("indexedFeatures").setMaxCategories(24)
val data2Es: OneHotEncoderEstimator = new OneHotEncoderEstimator().setInputCols(Array("season"
,"yr","mnth","hr","holiday","weekday","workingday","weathersit"))
.setOutputCols(Array("seasonVec","yrVec","mnthVec","hrVec","holidayVec",
"weekdayVec","workingdayVec","weathersitVec"))
val data2: DataFrame = data2Es.fit(data1).transform(data1)//One-hot encode the categorical columns
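//OneHotEncoderEstimator drops the last category by default (dropLast=true),
//so an n-category column becomes an (n-1)-dimensional sparse vector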
data2.printSchema()
data2.show(10)
//Split the encoded data into 80% training and 20% test sets
val Array(trainingData_lr,testData_lr) = data2.randomSplit(Array(0.8,0.2),seed=1234)
val assembler_lr1 = new VectorAssembler()
.setInputCols(Array("seasonVec","yrVec","mnthVec","hrVec",
"holidayVec","weekdayVec","workingdayVec","weathersitVec",
"temp","atemp","hum","windspeed"))
.setOutputCol("features_lr1")
//Import the SQLTransformer needed for the SQL-based label transformation
import org.apache.spark.ml.feature.SQLTransformer
val sqlTrans = new SQLTransformer().setStatement(
"SELECT *,SQRT(label) as label1 from __THIS__"
)
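//Why SQRT: the hourly rental counts are right-skewed, and the square-root
//transform compresses large counts so a linear model fits better; predictions
//will then be on the square-root scale (see the note near the end of main)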
//Configure the linear regression estimator
val lr1 = new LinearRegression()
.setFeaturesCol("features_lr1")
.setLabelCol("label1")
.setFitIntercept(true)
//Assemble the pipeline: feature assembly -> label transform -> regression
val pipeline_lr1 = new Pipeline().setStages(Array(assembler_lr1,sqlTrans,lr1))
//Build a parameter grid so the search below can pick the best hyperparameters
val paramGrid: Array[ParamMap] = new ParamGridBuilder()
.addGrid(lr1.elasticNetParam,Array(0.0,0.8,1.0)) //elastic-net mixing parameter (0 = L2, 1 = L1)
.addGrid(lr1.regParam,Array(0.1,0.3,0.5)) //regularization strength
.addGrid(lr1.maxIter,Array(20,30)).build()//maximum number of iterations
//Compare prediction against label1 to compute the test error
val evaluator_lr1 = new RegressionEvaluator()
.setLabelCol("label1")
.setPredictionCol("prediction")
.setMetricName("rmse")
//Use train-validation split: each parameter combination is scored once on a single held-out split (cheaper than cross-validation)
val trainValidationSplit = new TrainValidationSplit()
.setEstimator(pipeline_lr1)
.setEvaluator(evaluator_lr1)
.setEstimatorParamMaps(paramGrid)
.setTrainRatio(0.8)
//Fit on the training data; TrainValidationSplit internally holds out 20% of it
//(setTrainRatio(0.8)) to score each parameter combination and keep the best model
val lrModel1 = trainValidationSplit.fit(trainingData_lr)
//Inspect the search and model parameters
lrModel1.getEstimatorParamMaps.foreach { println } //the parameter combinations tried
println("--------------------------")
println(lrModel1.getEstimator.extractParamMap()) //the estimator's parameters
println(lrModel1.getEvaluator.isLargerBetter) //false: a lower RMSE is better
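//A sketch not in the original post: read back the winning hyperparameters from
//the best pipeline found by the search; the stage index 2 assumes the pipeline
//order set above (assembler_lr1, sqlTrans, lr1)
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.regression.LinearRegressionModel
val bestLr = lrModel1.bestModel.asInstanceOf[PipelineModel]
  .stages(2).asInstanceOf[LinearRegressionModel]
println(s"best regParam=${bestLr.getRegParam}, " +
  s"elasticNetParam=${bestLr.getElasticNetParam}, intercept=${bestLr.intercept}")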
//Predict on the test set using the best parameter combination
val predictions_lr1 = lrModel1.transform(testData_lr)
val rmse_lr1 = evaluator_lr1.evaluate(predictions_lr1)
println(s"Test RMSE on the sqrt scale: $rmse_lr1")
//Show the first five rows of features, labels, and predictions
predictions_lr1.select("features_lr1","label","label1","prediction").show(5)
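//A follow-up sketch not in the original post: since the model was trained on
//label1 = SQRT(label), predictions are on the square-root scale; squaring them
//recovers predictions on the original cnt scale for scoring
import org.apache.spark.sql.functions.{col, pow}
val predictionsCnt = predictions_lr1.withColumn("prediction_cnt", pow(col("prediction"), 2))
val rmseCnt = new RegressionEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction_cnt")
  .setMetricName("rmse")
  .evaluate(predictionsCnt)
println(s"Test RMSE on the original cnt scale: $rmseCnt")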
}
}
For the matching pom file, see my other blog post on the spark-ml pom file.
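For reference, a minimal sbt equivalent of those dependencies (a sketch only; it assumes Spark 2.3.x, where OneHotEncoderEstimator exists, built for Scala 2.11):
libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core"  % "2.3.0",
  "org.apache.spark" %% "spark-sql"   % "2.3.0",
  "org.apache.spark" %% "spark-mllib" % "2.3.0"
)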