Spark Parallel Development

Prerequisites

We need to train XGBoost models that predict longitude and latitude across the whole of Hangzhou. The data is split into many partition keys and each key gets its own pair of models, so the training jobs are launched in parallel on the driver using CountDownLatch.

CountDownLatch

Introduction to CountDownLatch
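
CountDownLatch (java.util.concurrent) is a one-shot synchronization aid: it is created with a count, each worker calls countDown() when its task finishes, and any thread blocked in await() is released once the count reaches zero. Paired with a fixed thread pool, it lets the driver submit many independent jobs and wait for all of them to complete. A minimal, self-contained sketch of the pattern (LatchDemo and the task body are illustrative, not from the project):

import java.util.concurrent.{CountDownLatch, Executors}

object LatchDemo {
  def main(args: Array[String]): Unit = {
    val pool  = Executors.newFixedThreadPool(4) // worker threads
    val latch = new CountDownLatch(10)          // one count per task

    for (i <- 1 to 10) {
      pool.execute(new Runnable {
        override def run(): Unit = {
          try {
            println(s"task $i on ${Thread.currentThread().getName}")
          } finally {
            latch.countDown() // count down even if the task failed
          }
        }
      })
    }

    latch.await()   // blocks until all 10 tasks have counted down
    pool.shutdown() // release the threads so the JVM can exit
  }
}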

Code Template

package com.nokia.zjbigdata.open.spark.impora

import java.util.concurrent.{CountDownLatch, Executors}

import ml.dmlc.xgboost4j.scala.spark.{XGBoostRegressionModel, XGBoostRegressor}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.slf4j.LoggerFactory

/**
  * description: train and upload the per-partition XGBoost models
  * Date: 2021/1/8 14:52
  */
object demoXgBoost24 {
  def main(args: Array[String]): Unit = {
    // Arguments are unused in this demo; the upstream example expected
    // input_path, native_model_path and pipeline_model_path
    if (args.length != 3) {
      println("Usage: SparkMLlibPipeline input_path native_model_path pipeline_model_path")
      //      sys.exit(1)
    }


    val LOG = LoggerFactory.getLogger(getClass)
    val conf = new SparkConf()
    conf.set("spark.yarn.am.waitTime", "3600")
    conf.set("spark.sql.broadcastTimeout", "3600")
    conf.set("spark.kryo.registrationRequired", "false")

    val spark = SparkSession.builder().appName("XGBoost4J-Spark Pipeline Example").enableHiveSupport().config(conf).getOrCreate()

    val sc: SparkContext = spark.sparkContext


    spark.sql("CREATE TEMPORARY FUNCTION key AS 'SC_LON_LST_UDF'")

    val assembler = new VectorAssembler().setInputCols(
      Array("sc_longitude", "sc_latitude", "sctadv", "sc_pci", "sc_freq", "scrsrp", "scrsrq"
        , "nc1pci", "nc1freq", "nc1rsrp", "nc1rsrq"
        , "nc2pci", "nc2freq", "nc2rsrp", "nc2rsrq"
        , "nc3pci", "nc3freq", "nc3rsrp", "nc3rsrq"
      )).setOutputCol("features")

    // Longitude regressor
    val booster = new XGBoostRegressor(
      Map("eta" -> 0.2f,
        "subsample" -> 0.9,
        "max_depth" -> 6,
        "missing" -> -999,
        "objective" -> "reg:squarederror",
        "colsample_bytree" -> 0.8,
        "colsample_bylevel" -> 0.8,
        "colsample_bynode" -> 0.8,
        "num_round" -> 500,
        "num_workers" -> 6,
        "tree_method" -> "auto"
      )
    )

    // Latitude regressor (identical apart from the larger eta)
    val booster_lat = new XGBoostRegressor(
      Map("eta" -> 0.5f,
        "subsample" -> 0.9,
        "max_depth" -> 6,
        "missing" -> -999,
        "objective" -> "reg:squarederror",
        "colsample_bytree" -> 0.8,
        "colsample_bylevel" -> 0.8,
        "colsample_bynode" -> 0.8,
        "num_round" -> 500,
        "num_workers" -> 6,
        "tree_method" -> "auto"
      )
    )

    booster.setFeaturesCol("features")
    booster.setLabelCol("longitude")
    booster_lat.setLabelCol("latitude")
    booster_lat.setFeaturesCol("features")

    // Driver-side thread pool: up to 40 training tasks run concurrently
    val service = Executors.newFixedThreadPool(40)
    // One latch count per partition key. Note: 5 to 1547 is 1543 keys, so size
    // the latch from the range itself; the original hard-coded 1542 was off by
    // one and let await() return before the last task finished.
    val keys = 5 to 1547
    val latch = new CountDownLatch(keys.size)

    for (k <- keys) {
      val runnable: Runnable = new Runnable {
        override def run(): Unit = {
          try {
            // Rows belonging to this partition key only
            val rawInput: DataFrame = spark.sql(
              s"""
                 |select
                 |*
                 |from mro_ns2_hive_db.tmp_mdt_hw_zx_njydwxt_xgboots_xl_partitions_02
                 |where key=$k
              """.stripMargin)

            val train = assembler.transform(rawInput.na.fill(-999)).select("features", "longitude", "latitude")
            // One longitude model and one latitude model per key
            val model = booster.fit(train)
            val model_lat = booster_lat.fit(train)

            model.write.overwrite().save("xgBoost/trains_lon_models/trains_models_lon_" + k)
            model_lat.write.overwrite().save("xgBoost/trains_lat_models/trains_models_lat_" + k)

          } catch {
            case e: Exception =>
              e.printStackTrace()
          } finally {
            // Always count down, even when training fails, so await() cannot hang
            latch.countDown()
          }
        }
      }
      service.execute(runnable)
    }

    try {
      latch.await() // block the driver until every partition task has finished
    } catch {
      case e: Exception =>
        e.printStackTrace()
    }
    // Release the pool threads; the non-daemon workers would otherwise keep the JVM alive
    service.shutdown()

    // Prediction step (kept commented out in the original): load a pair of saved
    // models, score the test set, and write the haversine error to Hive.
    //    val rawTest: DataFrame = spark.sql(
    //      s"""
    //         |select * from mro_ns2_hive_db.tmp_mdt_hw_zx_njydwxt_xgboots_cs_partitions_test
    //          """.stripMargin)
    //    val test = assembler.transform(rawTest.na.fill(-999)).select("features", "longitude", "latitude")
    //
    //    val model_read = XGBoostRegressionModel.load("xgBoost/trains_lon_models/trains_models_lon_1")
    //    val model_lat_read = XGBoostRegressionModel.load("xgBoost/trains_lat_models/trains_models_lat_1")
    //    val prediction: DataFrame = model_read.transform(test)
    //    val prediction_lon: DataFrame = prediction.withColumnRenamed("prediction", "prediction_longitude")
    //    val prediction_lat: DataFrame = model_lat_read.transform(prediction_lon)
    //    prediction_lat.createOrReplaceTempView("prediction_lat")
    //
    //    // Haversine distance (metres) between the true and predicted coordinates
    //    spark.sql(
    //      """
    //        |insert into table mro_ns2_hive_db.tmp_mro_msisdn_njy_xgboots_text_4
    //        |select
    //        |(6378137*2*ASIN(SQRT(POWER(SIN((latitude-prediction)*ACOS(-1)/360),2)+
    //        |COS(latitude*ACOS(-1)/180)*COS(prediction*ACOS(-1)/180)*POWER(SIN((longitude-prediction_longitude)*ACOS(-1)/360),2)))) as avg
    //        |from prediction_lat
    //        |where longitude>100
    //      """.stripMargin)
  spark.stop()
  }
}
/*

spark-submit \
--class com.nokia.zjbigdata.open.spark.impora.demoXgBoost24 \
--master yarn \
--deploy-mode cluster \
--driver-memory 8g \
--executor-memory 4g \
--executor-cores  3 \
--num-executors   150 \
--jars  xgboost.jar,xgboost4j-0.90.jar,xgboost4j_spark_2.3.2_0.90.jar,fastjson-1.2.47.jar,UDF.jar \
--name mro_demoXgBoost \
xgboost.jar
 */
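
One sizing note on this submission: the driver pool runs up to 40 training jobs at once, and each XGBoostRegressor asks for num_workers = 6, so the cluster needs roughly 40 × 6 = 240 free task slots at peak; 150 executors × 3 cores = 450 slots leaves headroom. If the pool were sized beyond the available slots, XGBoost jobs could stall waiting for all of their workers to be scheduled at once.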

Array par.foreach

// Collect the distinct partition keys to the driver
// (rawInput_xl / rawTest_cs are assumed to hold the training and test data)
val train_keys: Array[String] = rawInput_xl.rdd.map(_.get(0).toString).collect()
val text_keys: Array[String] = rawTest_cs.rdd.map(_.get(0).toString).collect()

// Iterate train_keys as a parallel collection
train_keys.par.foreach { key =>
  // put the per-key task here
}
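
For comparison, here is a hedged sketch of the same per-key training loop written with a parallel collection instead of the executor/latch pair. It assumes spark, assembler, booster and booster_lat from the code template above are in scope; the ForkJoinPool cap of 40 mirrors the fixed thread pool (scala.concurrent.forkjoin.ForkJoinPool is the Scala 2.11 type and survives as a deprecated alias on 2.12):

import scala.collection.parallel.ForkJoinTaskSupport
import scala.concurrent.forkjoin.ForkJoinPool

val parKeys = (5 to 1547).par
// Cap concurrency at 40, matching Executors.newFixedThreadPool(40) above
parKeys.tasksupport = new ForkJoinTaskSupport(new ForkJoinPool(40))

parKeys.foreach { k =>
  val rawInput = spark.sql(
    s"""
       |select *
       |from mro_ns2_hive_db.tmp_mdt_hw_zx_njydwxt_xgboots_xl_partitions_02
       |where key=$k
     """.stripMargin)

  val train = assembler.transform(rawInput.na.fill(-999))
    .select("features", "longitude", "latitude")

  booster.fit(train).write.overwrite()
    .save("xgBoost/trains_lon_models/trains_models_lon_" + k)
  booster_lat.fit(train).write.overwrite()
    .save("xgBoost/trains_lat_models/trains_models_lat_" + k)
}
// foreach on a parallel collection blocks until every element is processed,
// so no CountDownLatch is needed

Unlike the thread-pool version, exceptions here propagate out of foreach, so wrap the body in try/catch if one failed key should not abort the rest.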
