SparkML -- Exporting a LightGBM on Spark Model to PMML: An Example

MAVEN

<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>jpmml-sparkml</artifactId>
    <version>1.5.0</version>
    <!--
    <exclusions>
        <exclusion>
            <groupId>org.jpmml</groupId>
            <artifactId>jpmml-converter</artifactId>
        </exclusion>
    </exclusions>
    -->
</dependency>
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>jpmml-lightgbm</artifactId>
    <version>1.2.3</version>
</dependency>
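For sbt builds, the same two coordinates can be declared as follows. This is a direct transcription of the Maven snippet above, assuming the (commented-out) exclusion is not needed. The LightGBM on Spark (MMLSpark) artifact that provides `com.microsoft.ml.spark.lightgbm` is not part of this list and has to be added separately.

```scala
// build.sbt (sketch): same coordinates as the Maven snippet above
libraryDependencies ++= Seq(
  "org.jpmml" % "jpmml-sparkml" % "1.5.0",
  "org.jpmml" % "jpmml-lightgbm" % "1.2.3"
)
```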

Test data

http://archive.ics.uci.edu/ml/machine-learning-databases/00275/Bike-Sharing-Dataset.zip

hour.csv and day.csv both have the following fields, except that day.csv has no hr field:

  • instant: record ID
  • dteday: date
  • season: season (1: spring, 2: summer, 3: fall, 4: winter)
  • yr: year (0: 2011, 1: 2012)
  • mnth: month (1 to 12)
  • hr: hour of the day (0 to 23; hour.csv only)
  • holiday: whether the day is a holiday (extracted from http://dchr.dc.gov/page/holiday-schedule)
  • weekday: day of the week
  • workingday: 1 for a working day, 0 otherwise
  • weathersit: weather situation
      • 1: Clear, Few clouds, Partly cloudy
      • 2: Mist + Cloudy, Mist + Broken clouds, Mist + Few clouds, Mist
      • 3: Light Snow, Light Rain + Thunderstorm + Scattered clouds, Light Rain + Scattered clouds
      • 4: Heavy Rain + Ice Pellets + Thunderstorm + Mist, Snow + Fog
  • temp: normalized temperature in Celsius; the values are divided by 41 (max)
  • atemp: normalized "feels like" temperature in Celsius; the values are divided by 50 (max)
  • hum: normalized humidity; the values are divided by 100 (max)
  • windspeed: normalized wind speed; the values are divided by 67 (max)
  • casual: count of casual users
  • registered: count of registered users
  • cnt: target variable, the count of total rental bikes (per hour in hour.csv), including both casual and registered users
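Taking the divisor-based descriptions in the field list at face value (an assumption worth checking against the Readme shipped in the zip), the weather fields can be mapped back to their original scale by multiplying by the stated maximum, and cnt should always equal casual + registered. The sketch below illustrates both; the object and method names are hypothetical.

```scala
// Sketch only: reverses the scaling described in the field list,
// assuming each value was simply divided by the stated maximum.
object BikeSharingScaling {
  // Divisors taken from the field descriptions (temp/41, atemp/50, hum/100, windspeed/67)
  val divisors: Map[String, Double] = Map(
    "temp" -> 41.0,
    "atemp" -> 50.0,
    "hum" -> 100.0,
    "windspeed" -> 67.0
  )

  // Convert a normalized value back to its original scale
  def denormalize(field: String, normalized: Double): Double =
    divisors.get(field) match {
      case Some(d) => normalized * d
      case None    => throw new IllegalArgumentException(s"no divisor known for field: $field")
    }

  // cnt is defined as casual + registered
  def countIsConsistent(casual: Int, registered: Int, cnt: Int): Boolean =
    casual + registered == cnt

  def main(args: Array[String]): Unit = {
    println(denormalize("hum", 0.5))
    println(countIsConsistent(3, 13, 16))
  }
}
```

The same check can be expressed on a Spark DataFrame as a filter on `casual + registered =!= cnt`, which should return no rows.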

Code example, using binary classification (label: workingday)

package com.bigblue.lightgbm

import com.bigblue.utils.LightGBMUtils
import com.microsoft.ml.spark.lightgbm.{LightGBMClassificationModel, LightGBMClassifier}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.types.{DoubleType, IntegerType}
import org.apache.spark.sql.{DataFrame, SparkSession}

/**
 * Created By TheBigBlue on 2020/3/6
 * Description : trains a LightGBM binary classifier on Spark and exports it to PMML
 */
object LightGBMClassificationTest {

  def main(args: Array[String]): Unit = {

    val spark: SparkSession = SparkSession.builder().appName("test-lightgbm").master("local[2]").getOrCreate()
    spark.sparkContext.setLogLevel("WARN")
    var originalData: DataFrame = spark.read.option("header", "true") // treat the first line as the header
      .option("inferSchema", "true") // infer column types
      //      .csv("/home/hdfs/hour.csv")
      .csv("file:///D:/Cache/ProgramCache/TestData/dataSource/lightgbm/hour.csv")

    val labelCol = "workingday"
    // categorical columns
    val cateCols = Array("season", "yr", "mnth", "hr")
    // continuous columns
    val conCols: Array[String] = Array("temp", "atemp", "hum", "casual", "cnt")
    // feature columns
    val vecCols = conCols ++ cateCols

    import spark.implicits._
    vecCols.foreach(col => {
      originalData = originalData.withColumn(col, $"$col".cast(DoubleType))
    })
    originalData = originalData.withColumn(labelCol, $"$labelCol".cast(IntegerType))

    val assembler = new VectorAssembler().setInputCols(vecCols).setOutputCol("features")

    val classifier: LightGBMClassifier = new LightGBMClassifier().setNumIterations(100).setNumLeaves(31)
      .setBoostFromAverage(false).setFeatureFraction(1.0).setMaxDepth(-1).setMaxBin(255)
      .setLearningRate(0.1).setMinSumHessianInLeaf(0.001).setLambdaL1(0.0).setLambdaL2(0.0)
      .setBaggingFraction(1.0).setBaggingFreq(0).setBaggingSeed(1).setObjective("binary")
      .setLabelCol(labelCol).setCategoricalSlotNames(cateCols).setFeaturesCol("features")
      .setBoostingType("gbdt") // alternatives: rf, dart, goss

    val pipeline: Pipeline = new Pipeline().setStages(Array(assembler, classifier))

    // 70/30 train/test split with a fixed seed
    val Array(tr, te) = originalData.randomSplit(Array(0.7, 0.3), 666)
    val model = pipeline.fit(tr)
    val modelDF = model.transform(te)
    // AUC needs a continuous score, so evaluate on rawPrediction rather than the 0/1 prediction column
    val evaluator = new BinaryClassificationEvaluator().setLabelCol(labelCol).setRawPredictionCol("rawPrediction")
    println(evaluator.evaluate(modelDF))

    // export the trained LightGBM model to PMML
    val classificationModel = model.stages(1).asInstanceOf[LightGBMClassificationModel]
    LightGBMUtils.saveToPmml(classificationModel.getModel, "D://Download/classificationModel.xml")
  }
}
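The evaluator in this example reports the default metric of BinaryClassificationEvaluator, areaUnderROC. AUC depends only on how examples are ranked by score, so it should be computed from a continuous score (rawPrediction or probability) rather than from the hard 0/1 prediction column, where every positive/negative pair collapses into a win, a loss, or a tie. The plain-Scala sketch below illustrates this with the pairwise (Mann-Whitney) definition of AUC on made-up toy numbers; it is not Spark's implementation, and the names are hypothetical.

```scala
// Sketch: ROC AUC via pairwise comparison (Mann-Whitney form).
object AucSketch {
  // AUC = P(score of a random positive > score of a random negative); ties count 0.5
  def auc(scores: Seq[Double], labels: Seq[Int]): Double = {
    require(scores.length == labels.length, "scores and labels must align")
    val pos = scores.zip(labels).collect { case (s, 1) => s }
    val neg = scores.zip(labels).collect { case (s, 0) => s }
    require(pos.nonEmpty && neg.nonEmpty, "need both classes")
    val wins = for (p <- pos; n <- neg) yield {
      if (p > n) 1.0 else if (p == n) 0.5 else 0.0
    }
    wins.sum / (pos.length * neg.length)
  }

  def main(args: Array[String]): Unit = {
    val labels = Seq(0, 0, 1, 1)
    val scores = Seq(0.1, 0.6, 0.4, 0.9)
    // thresholding at 0.5 mimics using the prediction column
    val hard = scores.map(s => if (s > 0.5) 1.0 else 0.0)
    println(auc(scores, labels)) // 0.75: ranking information is kept
    println(auc(hard, labels))   // 0.5: ranking collapses to ties
  }
}
```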

The helper object LightGBMUtils used above:

package com.bigblue.utils

import java.io.{ByteArrayInputStream, FileOutputStream}

import com.microsoft.ml.spark.lightgbm.LightGBMBooster
import org.jpmml.lightgbm.LightGBMUtil
import org.jpmml.model.MetroJAXBUtil

/**
 * Created By TheBigBlue on 2020/3/20
 * Description : converts a LightGBMBooster to PMML and writes it to the given path
 */
object LightGBMUtils {

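  def saveToPmml(booster: LightGBMBooster, path: String): Unit = {
    try {
      // parse the LightGBM text model held by the booster
      val gbdt = LightGBMUtil.loadGBDT(new ByteArrayInputStream(booster.model.getBytes))
      import scala.collection.JavaConversions.mapAsJavaMap
      // encode to PMML with the converter's "compact" option enabled
      val pmml = gbdt.encodePMML(null, null, Map("compact" -> true))
      // close the output stream even if marshalling fails
      val out = new FileOutputStream(path)
      try MetroJAXBUtil.marshalPMML(pmml, out)
      finally out.close()
    } catch {
      case e: Exception => e.printStackTrace()
    }
  }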

}
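To check that saveToPmml actually produced a PMML document, the file can be parsed back with the JDK's built-in DOM parser. The sketch below (object and method names are hypothetical) returns the root element's local name and its version attribute; the inline sample XML and its namespace are illustrative only and not necessarily what jpmml-lightgbm 1.2.3 writes.

```scala
import java.io.{ByteArrayInputStream, InputStream}
import java.nio.charset.StandardCharsets
import javax.xml.parsers.DocumentBuilderFactory

// Sketch: read back an exported file and report its root element and PMML version.
object PmmlSanityCheck {
  def rootInfo(in: InputStream): (String, String) = {
    val factory = DocumentBuilderFactory.newInstance()
    factory.setNamespaceAware(true) // so getLocalName ignores the namespace prefix
    val root = factory.newDocumentBuilder().parse(in).getDocumentElement
    (root.getLocalName, root.getAttribute("version"))
  }

  def main(args: Array[String]): Unit = {
    val sample = """<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3"><Header/></PMML>"""
    val in = new ByteArrayInputStream(sample.getBytes(StandardCharsets.UTF_8))
    try println(rootInfo(in)) finally in.close()
  }
}
```

For the real output, pass a `java.io.FileInputStream` opened on the path given to saveToPmml (and close it afterwards); the first element of the returned pair should be `PMML`.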

Result

[Figure: screenshot of the program output]
