MAVEN
<dependency>
<groupId>com.microsoft.ml.spark</groupId>
<artifactId>mmlspark_2.11</artifactId>
<version>0.18.0</version>
</dependency>
<dependency>
<groupId>com.microsoft.ml.lightgbm</groupId>
<artifactId>lightgbmlib</artifactId>
<version>2.2.350</version>
</dependency>
测试数据
http://archive.ics.uci.edu/ml/machine-learning-databases/00275/Bike-Sharing-Dataset.zip
hour.csv和day.csv都有如下属性,除了day.csv文件中没有hr属性以外
- instant: 记录ID
- dteday : 时间日期
- season : 季节 (1:春季, 2:夏季, 3:秋季, 4:冬季)
- yr : 年份 (0: 2011, 1:2012)
- mnth : 月份 ( 1 to 12)
- hr : 当天时刻 (0 to 23)
- holiday : 当天是否是节假日(extracted from http://dchr.dc.gov/page/holiday-schedule)
- weekday : 周几
- workingday : 工作日 is 1, 其他 is 0.
- weathersit : 天气
- 1: Clear, Few clouds, Partly cloudy, Partly cloudy
- 2: Mist + Cloudy, Mist + Broken clouds, Mist + Few clouds, Mist
- 3: Light Snow, Light Rain + Thunderstorm + Scattered clouds, Light Rain + Scattered clouds
- 4: Heavy Rain + Ice Pallets + Thunderstorm + Mist, Snow + Fog
- temp : 气温 Normalized temperature in Celsius. The values are divided to 41 (max)
- atemp: 体感温度 Normalized feeling temperature in Celsius. The values are divided to 50 (max)
- hum: 湿度 Normalized humidity. The values are divided to 100 (max)
- windspeed: 风速 Normalized wind speed. The values are divided to 67 (max)
- casual: 临时用户数 count of casual users
- registered: 注册用户数 count of registered users
- cnt: 目标变量,每小时的自行车的租用量,包括临时用户和注册用户count of total rental bikes including both casual and registered
代码示例
package com.bigblue.lightgbm
import com.microsoft.ml.spark.lightgbm.{LightGBMRanker, LightGBMRankerModel}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.types.{DoubleType, IntegerType}
import org.apache.spark.sql.{DataFrame, SparkSession}
/**
* Created By TheBigBlue on 2020/3/4
* Description :
*/
object LightGBMRankerTest {

  /**
   * Trains a LightGBM ranker on the UCI bike-sharing dataset, reads the
   * split-based feature importances from the fitted model, and keeps only
   * the top 60% most important feature columns.
   *
   * @param args unused command-line arguments
   */
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .appName("test-lightgbm")
      .master("local[2]")
      .getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    // Load the CSV using the first row as the header and inferring column types.
    val originalData: DataFrame = spark.read
      .option("header", "true")
      .option("inferSchema", "true")
      // .csv("/home/hdfs/hour.csv")
      .csv("file:///D:/Cache/ProgramCache/TestData/dataSource/lightgbm/hour.csv")

    val labelCol = "workingday"
    // Categorical feature columns
    val cateCols = Array[String]("season", "yr", "mnth", "hr")
    // Continuous feature columns
    val conCols: Array[String] = Array("temp", "atemp", "hum", "casual", "cnt")
    // All feature columns fed to the VectorAssembler
    val vecCols = conCols ++ cateCols

    import spark.implicits._
    // Cast every feature column to Double, then the label to Int.
    // foldLeft replaces the original mutable `var inputDF` + foreach pattern.
    val castDF = vecCols
      .foldLeft(originalData.select(labelCol, vecCols: _*)) { (df, col) =>
        df.withColumn(col, $"$col".cast(DoubleType))
      }
      .withColumn(labelCol, $"$labelCol".cast(IntegerType))

    // Append an index column used as the group column;
    // LightGBMRanker fails when no group column is specified.
    import org.apache.spark.sql.functions._
    val inputDF = castDF.withColumn("index", monotonically_increasing_id)
    inputDF.show

    val assembler = new VectorAssembler().setInputCols(vecCols).setOutputCol("features")
    // The group column is mandatory for the ranker.
    val classifier: LightGBMRanker = new LightGBMRanker().setNumIterations(100).setNumLeaves(31)
      .setBoostFromAverage(false).setFeatureFraction(1.0).setMaxDepth(-1).setMaxBin(255)
      .setLearningRate(0.1).setMinSumHessianInLeaf(0.001).setLambdaL1(0.0).setLambdaL2(0.0)
      .setBaggingFraction(1.0).setBaggingFreq(0).setBaggingSeed(1).setObjective("lambdarank")
      .setLabelCol(labelCol).setCategoricalSlotNames(cateCols).setFeaturesCol("features")
      .setGroupCol("index").setBoostingType("gbdt")

    val pipelineModel = new Pipeline().setStages(Array(assembler, classifier)).fit(inputDF)
    val rankerModel = pipelineModel.stages(1).asInstanceOf[LightGBMRankerModel]

    // "split" importance: how often each feature is used in a tree split.
    val importanceValues = rankerModel.getFeatureImportances("split")
    // Sort by importance (descending) and keep the top 60% of features.
    val filteredTuples = vecCols.zip(importanceValues)
      .sortBy(-_._2)
      .take((0.6 * vecCols.length).toInt)

    // Build the importance DataFrame with a 1-based rank id.
    // zipWithIndex replaces the original error-prone mutable counter.
    val importanceRows: Array[LightGBMRankerTest] = filteredTuples.zipWithIndex.map {
      case ((name, value), i) => LightGBMRankerTest(i + 1, name, value)
    }
    val importanceDF = spark.createDataFrame(importanceRows)
    importanceDF.show

    // Keep only the label plus the selected feature columns.
    val filteredCols: Array[String] = filteredTuples.map(_._1)
    val finalDF = inputDF.select(labelCol, filteredCols: _*)
    finalDF.show
  }
}
/** Row type for the feature-importance DataFrame: 1-based rank id, feature column name, and its importance value. */
case class LightGBMRankerTest(id: Long, feature_name: String, value: Double)