Background
We need to train XGBoost models that predict longitude and latitude across the whole of Hangzhou, one model pair per geographic partition, so the per-partition training tasks are parallelized with a CountDownLatch.
CountDownLatch
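A CountDownLatch is created with a count equal to the number of tasks; each task calls countDown() when it finishes, and the main thread's await() blocks until the count reaches zero. A minimal sketch of the pattern (the object name and task body are illustrative only, not from the project):

import java.util.concurrent.{CountDownLatch, Executors}

object LatchDemo {
  def main(args: Array[String]): Unit = {
    val taskCount = 4
    // fixed pool: at most 2 tasks run at the same time
    val pool = Executors.newFixedThreadPool(2)
    // the count must equal the number of submitted tasks
    val latch = new CountDownLatch(taskCount)
    for (i <- 1 to taskCount) {
      pool.execute(new Runnable {
        override def run(): Unit = {
          try {
            println(s"task $i on ${Thread.currentThread().getName}")
          } finally {
            latch.countDown() // decrement even if the task threw
          }
        }
      })
    }
    latch.await() // blocks until all taskCount tasks have counted down
    pool.shutdown()
    println("all tasks finished")
  }
}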
Code template
package com.nokia.zjbigdata.open.spark.impora
import java.util.concurrent.{CountDownLatch, Executors}
import ml.dmlc.xgboost4j.scala.spark.{XGBoostRegressionModel, XGBoostRegressor}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.slf4j.LoggerFactory
/**
*
* program:
* description: upload the trained models
* Author:
* Date:2021/1/8 14:52
*
*/
object demoXgBoost24 {
def main(args: Array[String]): Unit = {
// the command-line arguments are not used below; this check is kept from the
// upstream SparkMLlibPipeline example, which takes three paths
if (args.length != 3) {
println("Usage: SparkMLlibPipeline input_path native_model_path pipeline_model_path")
// sys.exit(1)
}
val LOG = LoggerFactory.getLogger(getClass)
val conf = new SparkConf()
conf.set("spark.yarn.am.waitTime", "3600")
conf.set("spark.sql.broadcastTimeout", "3600")
conf.set("spark.kryo.registrationRequired", "false")
val spark = SparkSession.builder()
  .appName("XGBoost4J-Spark Pipeline Example")
  .enableHiveSupport()
  .config(conf)
  .getOrCreate()
val sc: SparkContext = spark.sparkContext
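// register the partition-key UDF (assumed to be provided by the UDF.jar passed via --jars)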
spark.sql("CREATE TEMPORARY FUNCTION key AS 'SC_LON_LST_UDF'")
val assembler = new VectorAssembler().setInputCols(
Array("sc_longitude", "sc_latitude", "sctadv", "sc_pci", "sc_freq", "scrsrp", "scrsrq"
, "nc1pci", "nc1freq", "nc1rsrp", "nc1rsrq"
, "nc2pci", "nc2freq", "nc2rsrp", "nc2rsrq"
, "nc3pci", "nc3freq", "nc3rsrp", "nc3rsrq"
)).setOutputCol("features")
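// two regressors over the same features: one predicts longitude, one latitude
// (identical hyper-parameters except the learning rate eta)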
val booster = new XGBoostRegressor(
Map("eta" -> 0.2f,
"subsample" -> 0.9,
"max_depth" -> 6,
"missing" -> -999,
"objective" -> "reg:squarederror",
"colsample_bytree" -> 0.8,
"colsample_bylevel" -> 0.8,
"colsample_bynode" -> 0.8,
"num_round" -> 500,
"num_workers" -> 6,
"tree_method" -> "auto"
)
)
val booster_lat = new XGBoostRegressor(
Map("eta" -> 0.5f,
"subsample" -> 0.9,
"max_depth" -> 6,
"missing" -> -999,
"objective" -> "reg:squarederror",
"colsample_bytree" -> 0.8,
"colsample_bylevel" -> 0.8,
"colsample_bynode" -> 0.8,
"num_round" -> 500,
"num_workers" -> 6,
"tree_method" -> "auto"
)
)
booster.setFeaturesCol("features")
booster.setLabelCol("longitude")
booster_lat.setLabelCol("latitude")
booster_lat.setFeaturesCol("features")
// size of the thread pool: at most 40 training tasks run concurrently
val service = Executors.newFixedThreadPool(40)
// the latch count must equal the number of submitted tasks:
// k = 5 to 1547 submits 1543 tasks
val latch = new CountDownLatch(1543)
for (k <- 5 to 1547) {
val runnable: Runnable = new Runnable {
override def run(): Unit = {
try {
val rawInput: DataFrame = spark.sql(
s"""
|select
|*
|from mro_ns2_hive_db.tmp_mdt_hw_zx_njydwxt_xgboots_xl_partitions_02
|where key=$k
""".stripMargin)
val train = assembler.transform(rawInput.na.fill(-999)).select("features", "longitude", "latitude")
val model = booster.fit(train)
val model_lat = booster_lat.fit(train)
model.write.overwrite().save("xgBoost/trains_lon_models/trains_models_lon_" + k)
model_lat.write.overwrite().save("xgBoost/trains_lat_models/trains_models_lat_" + k)
} catch {
case e: Exception =>
e.printStackTrace()
} finally {
// decrement the latch whether training succeeded or failed
latch.countDown()
}
}
}
service.execute(runnable)
}
try{
latch.await()
}catch {
case e: Exception =>
e.printStackTrace()
}
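// Optional evaluation, kept commented out: load a saved longitude/latitude model
// pair, predict on a test set, and insert the Haversine distance (in metres,
// Earth radius 6378137 m) between predicted and true position into a result table.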
// val rawTest: DataFrame = spark.sql(
// s"""
// |select * from mro_ns2_hive_db.tmp_mdt_hw_zx_njydwxt_xgboots_cs_partitions_test
// """.stripMargin)
// val test = assembler.transform(rawTest.na.fill(-999)).select("features", "longitude","latitude")
//
//
// val model_read =XGBoostRegressionModel.load("xgBoost/trains_lon_models/trains_models_lon_1")
// val model_lat_read = XGBoostRegressionModel.load("xgBoost/trains_lat_models/trains_models_lat_1")
// val prediction: DataFrame = model_read.transform(test)
// val prediction_lon: DataFrame = prediction.withColumnRenamed("prediction","prediction_longitude")
// val prediction_lat: DataFrame = model_lat_read.transform(prediction_lon)
// prediction_lat.createOrReplaceTempView("prediction_lat")
//
// spark.sql(
// """
// |insert into table mro_ns2_hive_db.tmp_mro_msisdn_njy_xgboots_text_4
// |select
// |(6378137*2*ASIN(SQRT(POWER(SIN((latitude-prediction)*ACOS(-1)/360),2)+
// |COS(latitude*ACOS(-1)/180)*COS(prediction*ACOS(-1)/180)*POWER(SIN((longitude-prediction_longitude)*ACOS(-1)/360),2)))) as avg
// |from prediction_lat
// |where longitude>100
// """.stripMargin)
// shut down the pool so the driver JVM can exit cleanly
service.shutdown()
spark.stop()
}
}
/*
spark-submit \
--class com.nokia.zjbigdata.open.spark.impora.demoXgBoost24 \
--master yarn \
--deploy-mode cluster \
--driver-memory 8g \
--executor-memory 4g \
--executor-cores 3 \
--num-executors 150 \
--jars xgboost.jar,xgboost4j-0.90.jar,xgboost4j_spark_2.3.2_0.90.jar,fastjson-1.2.47.jar,UDF.jar \
--name mro_demoXgBoost \
xgboost.jar
*/
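A quick sizing check on the numbers above: each training job uses num_workers = 6 and the driver-side pool admits 40 jobs at once, so up to 40 × 6 = 240 XGBoost workers can be training simultaneously; the submit command requests 150 executors × 3 cores = 450 task slots, which leaves headroom for that load. These figures are simply derived from the parameters shown, not a tuning recommendation.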
Array par.foreach
Scala's parallel collections offer a lighter-weight alternative to the explicit thread pool and latch: calling .par.foreach on an array runs the body for its elements in parallel and returns only once all of them have finished.
// collect the distinct partition keys to the driver
val train_keys: Array[String] = rawInput_xl.rdd.map(_.get(0).toString).collect()
val text_keys: Array[String] = rawTest_cs.rdd.map(_.get(0).toString).collect()
// iterate over train_keys in parallel
train_keys.par.foreach { key =>
  println(key) // placeholder: put the per-key task here
}
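Combining the two, the per-key training from the template can be driven by par.foreach instead of an explicit pool and latch. A sketch assuming spark, assembler, booster, booster_lat and rawInput_xl from the code above are in scope, and that the key column is numeric as in the template; the ForkJoinPool cap of 40 mirrors the fixed thread pool and is an assumption, as are the imports, which target the Scala 2.11 build implied by the xgboost4j_spark_2.3.2 jar:

import scala.collection.parallel.ForkJoinTaskSupport
import scala.concurrent.forkjoin.ForkJoinPool

// collect the distinct partition keys to the driver
val keys: Array[String] = rawInput_xl.rdd.map(_.get(0).toString).collect()

val parKeys = keys.par
// cap concurrency at 40, matching the fixed thread pool used earlier
parKeys.tasksupport = new ForkJoinTaskSupport(new ForkJoinPool(40))

parKeys.foreach { k =>
  val rawInput = spark.sql(
    s"""
       |select *
       |from mro_ns2_hive_db.tmp_mdt_hw_zx_njydwxt_xgboots_xl_partitions_02
       |where key=$k
    """.stripMargin)
  val train = assembler.transform(rawInput.na.fill(-999))
    .select("features", "longitude", "latitude")
  // fit() blocks until the distributed training job for this key completes
  booster.fit(train).write.overwrite().save("xgBoost/trains_lon_models/trains_models_lon_" + k)
  booster_lat.fit(train).write.overwrite().save("xgBoost/trains_lat_models/trains_models_lat_" + k)
}
// par.foreach returns only after every key has been processed, so no latch is needed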