import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.{OneHotEncoderEstimator, VectorAssembler, VectorIndexer}
import org.apache.spark.ml.regression.DecisionTreeRegressor
import org.apache.spark.sql.SparkSession
//ML目前支持的回归模型:
//Linear regression 线性回归
//Generalized linear regression 广义线性回归
//Decision tree regression 决策树回归
//Random forest regression 随机森林回归
//Gradient-boosted tree regression梯度提高树回归
//Survival regression生存回归
//Isotonic regression保存回归
object RegressionDemo {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().config("spark.testing.memory", "2147480000")
.master("local[*]")
.appName("sda")
.getOrCreate()
//数据加载
val rawData = spark.read.format("csv").option("header",true).load("D:/hadoop/spark/bike/hour.csv")
rawData.show(6)
rawData.printSchema()
rawData.describe("dteday","holiday","weekday","temp").show()
// 上面的数据有很多字段是类型,使用回归算法时,需要通过OneHotEncoder把数据转换为二元向量,对一些字段或特征进行规范化
// 不过OneHotEncoder在2.3.0版本后不推荐使用,可以改用OneHotEncoderEstimator
// 数据预处理
// 1 特征选择
val data1 = rawData.select(rawData("season").cast("Double"),rawData("yr").cast("Double"),
rawData("mnth").cast("Double"),rawData("hr").cast("Double"),
rawData("holiday").cast("Double"),rawData("weekday").cast("Double"),
rawData("workingday").cast("Double"),rawData("weathersit").cast("Double"),
rawData("temp").cast("Double"),rawData("atemp").cast("Double"),
rawData("hum").cast("Double"),rawData("windspeed").cast("Double"),
rawData("cnt").cast("Double").alias("label"))
// 生成一个存放以上预测特征的特征向量
val featuresArray = Array("season","yr","mnth","hr","holiday","weekday","workingday","weathersit","temp",
"atemp","hum","windspeed")
// 把源数据组合成特征向量features
val assembler = new VectorAssembler().setInputCols(featuresArray).setOutputCol("features")
// 2 特征转换
// 这些特征大部分是分类特征,有些值是连续型,如气温、湿度等特征,使用决策树回归时,
// 技术上可以通过给定类别个数的最大值,自动识别哪些特征作为类别特征,哪些作为连续性特征。
// 这里把不同值小于或等于24的特征作为类别特征,大于24的视为连续性特征,并对分类特征索引化或数值化
val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(24)
// 如果使用线性回归 需要对类别使用OneHotEncoder转换二元向量
val data2 = new OneHotEncoderEstimator().setInputCols(Array("season")).setOutputCols(Array("seasonVec"))
val data3 = new OneHotEncoderEstimator().setInputCols(Array("yr")).setOutputCols(Array("yrVec"))
val data4 = new OneHotEncoderEstimator().setInputCols(Array("mnth")).setOutputCols(Array("mnthVec"))
val data5 = new OneHotEncoderEstimator().setInputCols(Array("hr")).setOutputCols(Array("hrVec"))
val data6 = new OneHotEncoderEstimator().setInputCols(Array("holiday")).setOutputCols(Array("holidayVec"))
val data7 = new OneHotEncoderEstimator().setInputCols(Array("weekday")).setOutputCols(Array("weekdayVec"))
val data8 = new OneHotEncoderEstimator().setInputCols(Array("workingday")).setOutputCols(Array("workingdayVec"))
val data9 = new OneHotEncoderEstimator().setInputCols(Array("weathersit")).setOutputCols(Array("weathersitVec"))
// 3 组装
val Array(trainingData,testData) = data1.randomSplit(Array(0.7,0.3),12)
// val Array(trainingData_lr, testData_lr) =
// 使用决策树模型
// 参数:
// featuresCol:特征列名,默认features
// labelCol:标签列名,默认label
// predictionCol:预测结果列名,默认prediction
// maxDepth:最大深度,默认5
// maxBins:连续特征离散化的最大数量,以及选择每个节点分裂特征的方式,默认32
// minInstancesPerNode:分裂后自节点最少包含的实例数量,默认1
// minInfoGain:分裂节点时所需最小信息增益,默认0.0
// maxMemoryInMB:分配给直方图聚合的最大内存,默认256MB
// cacheNodeIds=False
// checkpointInterval:设置检查点间隔,或不设置检查点(-1),默认10
// impurity:计算信息增益的准则,默认variance
// seed:随机种子
// varianceCol:预测的有偏样本偏差的列名,默认无
val dt = new DecisionTreeRegressor().setLabelCol("label").setFeaturesCol("indexedFeatures").setMaxBins(64).setMaxDepth(15)
// 把决策树回归模型涉及的特征转换及模型训练组装在一个流水线上
val pipeline = new Pipeline().setStages(Array(assembler,featureIndexer,dt))
// 训练模型
val model = pipeline.fit(trainingData)
// 做出预测
val predictions = model.transform(testData)
// 评估模型
// RegressionEvaluator.setMetricName可以定义四种评估器:rmse默认,mse,r2,mae
val evaluator = new RegressionEvaluator().setLabelCol("label").setPredictionCol("prediction").setMetricName("rmse")
// 决策树模型评估指标
val rmse = evaluator.evaluate(predictions)
println(rmse)
// 模型优化
}
}