SparkUtils utility class
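A grab-bag of Spark helpers reused across jobs: HTTP POST calls with a single retry, writing RDD/DataFrame results into partitioned Hive tables and MySQL, converting DataFrame rows into fastjson JSONObjects, salted two-step groupByKey aggregation, and simple per-minute rate limiting for external API calls.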

package com.xxx.sparktest.utils

import java.sql.DriverManager
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.atomic.AtomicInteger
import java.util.{Calendar, Map, Properties}

import com.alibaba.fastjson.{JSONArray, JSONObject}
import com.fengtu.sparktest.utils2.Utils
import org.apache.log4j.Logger
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
import org.apache.spark.sql.types.StructType
import org.apache.spark.storage.StorageLevel

import scala.collection.mutable.ArrayBuffer
import scala.util.Random

object SparkUtils {

  @transient lazy val logger: Logger = Logger.getLogger(this.getClass)



  // Curried aggregation functions for aggregateByKey: fold JSONObjects into a List[JSONObject]
  val seqOp = (a: List[JSONObject], b: JSONObject) => a.size match {
    case 0 => List(b)
    case _ => b::a
  }

  val combOp = (a: List[JSONObject], b: List[JSONObject]) => {
    a ::: b
  }

  val seqOpRow = (a: List[Row], b: Row) => a.size match {
    case 0 => List(b)
    case _ => b::a
  }

  val combOpRow = (a: List[Row], b: List[Row]) => {
    a ::: b
  }

  // Send a POST request; retry once on failure
  def doPost(url: String, reqJson: JSONObject, logger: Logger): String = {
    var responseBody = "{}"

    try {
      responseBody = Utils.post(url, reqJson, "utf-8")
    } catch {
      case e: Exception => logger.error(e + s"\n>>>POST request failed<<<\n$reqJson")
        try {
          responseBody = Utils.post(url, reqJson, "utf-8")
        } catch {
          case e: Exception => logger.error(e + s"\n>>>POST request failed<<<\n$reqJson")
            responseBody = "POST request failed: " + e.toString
        }
    }
    responseBody
  }

  // Send a POST form request; retry once on failure
  def doPostForm(url: String, map: Map[String, String], logger: Logger): String = {
    var responseBody = "{}"

    try {
      responseBody = Utils.postForm(url, map, "utf-8")
    } catch {
      case e: Exception => logger.error(e + s"\n>>>POST request failed<<<\n${map.toString}")
        try {
          responseBody = Utils.postForm(url, map, "utf-8")
        } catch {
          case e: Exception => logger.error(e + s"\n>>>POST request failed<<<\n${map.toString}")
            responseBody = "POST request failed: " + e.toString
        }
    }

    responseBody
  }


  /**
    * Send a POST request with a JSON array body; retry once on failure
    * @param url        target URL
    * @param reqJsonArr request body
    * @param logger     logger for failures
    * @return response body, or an error message if both attempts fail
    */
  def doPostArr(url: String, reqJsonArr: JSONArray, logger: Logger): String = {
    var responseBody = "{}"

    try {
      responseBody = Utils.post(url, reqJsonArr, "utf-8")
    } catch {
      case e: Exception => logger.error(e + s"\n>>>POST request failed<<<\n$reqJsonArr")
        try {
          responseBody = Utils.post(url, reqJsonArr, "utf-8")
        } catch {
          case e: Exception => logger.error(e + s"\n>>>POST request failed<<<\n$reqJsonArr")
            responseBody = "POST request failed: " + e.toString
        }
    }
    responseBody
  }

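  /**
    * Write an RDD[Row] (or a DataFrame, in the overloads below) into a partitioned Hive table:
    * the target partition is dropped first, then the data is written with saveAsTable.
    */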
  def df2Hive(spark: SparkSession, rdd: RDD[Row], schema: StructType, saveMode: String, descTable: String, partitionSchm: String, incDay: String, logger: Logger): Unit = {
    logger.error(s"Writing to Hive table $descTable ...")
    val df = spark.sqlContext.createDataFrame(rdd, schema)
    // Drop the target partition before writing
    val dropSql = s"alter table $descTable drop if exists partition($partitionSchm='$incDay')"
    logger.error(dropSql)
    spark.sql(dropSql)

    df.write.format("hive").mode(saveMode).partitionBy(partitionSchm).saveAsTable(descTable)

    logger.error(s"Wrote partition $incDay successfully")
  }

  def df2HivePs(spark: SparkSession, rdd: RDD[Row], schema: StructType, saveMode: String, descTable: String, incDay: String, region: String, logger: Logger, partitionDay: String, partitionRegion: String): Unit = {
    logger.error(s"Writing to Hive table $descTable ...")
    val df = spark.sqlContext.createDataFrame(rdd, schema)
    // Drop the target partition before writing
    val dropSql = s"alter table $descTable drop if exists partition($incDay='$partitionDay',$region='$partitionRegion')"
    logger.error(dropSql)
    spark.sql(dropSql)

    df.write.format("hive").mode(saveMode).partitionBy(incDay, region).saveAsTable(descTable)

    logger.error(s"Wrote partition $partitionDay,$partitionRegion successfully")
  }

  def df2Hive(spark: SparkSession, df: DataFrame, saveMode: String, descTable: String, partitionSchm: String, incDay: String, logger: Logger): Unit = {
    logger.error(s"Writing to Hive table $descTable ...")

    // Drop the target partition before writing
    val dropSql = s"alter table $descTable drop if exists partition($partitionSchm='$incDay')"
    logger.error(dropSql)
    spark.sql(dropSql)

    df.write.format("hive").mode(saveMode).partitionBy(partitionSchm).saveAsTable(descTable)

    logger.error(s"Wrote partition $incDay successfully")
  }


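  /**
    * Convert every row of a DataFrame into a fastjson JSONObject (all columns read as String),
    * repartition to parNum partitions and persist the resulting RDD. The overloads below
    * additionally let the caller choose the storage level by name ("DISK_ONLY", otherwise MEMORY_AND_DISK_SER).
    */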
  def getRowToJson(sourDf:DataFrame,parNum:Int = 200) ={

    val colList = sourDf.columns

    val sourRdd = sourDf.rdd.repartition(parNum).map( obj => {
      val jsonObj = new JSONObject()
      for (columns <- colList) {
        jsonObj.put(columns,obj.getAs[String](columns))
      }
      jsonObj
    }).persist(StorageLevel.DISK_ONLY)

    println(s"共获取数据:${sourRdd.count()}")
    //sourRdd.take(2).foreach(println(_))
    sourDf.unpersist()

    sourRdd
  }



  def getRowToJson(sourDf:DataFrame,persistLevel:String) ={
    val parNum = 200
    var storageLevel = StorageLevel.MEMORY_AND_DISK_SER

    if (persistLevel.equals("DISK_ONLY")){
      storageLevel =  StorageLevel.DISK_ONLY
    }


    val colList = sourDf.columns

    val sourRdd = sourDf.rdd.repartition(parNum).map( obj => {
      val jsonObj = new JSONObject()
      for (columns <- colList) {
        jsonObj.put(columns,obj.getAs[String](columns))
      }
      jsonObj
    }).persist(storageLevel)

    println(s"共获取数据:${sourRdd.count()}")
    //sourRdd.take(2).foreach(println(_))
    sourDf.unpersist()

    sourRdd
  }

  def getRowToJson(sourDf:DataFrame,persistLevel:String,parNum:Int) ={
    var storageLevel = StorageLevel.MEMORY_AND_DISK_SER

    if (persistLevel.equals("DISK_ONLY")){
      storageLevel =  StorageLevel.DISK_ONLY
    }


    val colList = sourDf.columns

    val sourRdd = sourDf.rdd.repartition(parNum).map( obj => {
      val jsonObj = new JSONObject()
      for (columns <- colList) {
        jsonObj.put(columns,obj.getAs[String](columns))
      }
      jsonObj
    }).persist(storageLevel)

    println(s"共获取数据:${sourRdd.count()}")
    //sourRdd.take(2).foreach(println(_))
    sourDf.unpersist()

    sourRdd
  }



  /**
    * Write an RDD[Row] into a MySQL table: delete the rows for the given day first, then append.
    *
    * @param spark    active SparkSession
    * @param rdd      data to write
    * @param schema   schema used to build the DataFrame
    * @param user     MySQL user
    * @param password MySQL password
    * @param saveMode not used by the body; data is always appended
    * @param jdbcUrl  JDBC connection URL
    * @param tblName  target table name
    * @param incDay   date value whose rows are deleted and re-inserted
    * @param logger   logger for progress messages
    * @param statdate name of the date column, defaults to "statdate"
    */
  def df2Mysql(spark: SparkSession,rdd: RDD[Row],schema:StructType,user:String,password:String,
               saveMode:String,jdbcUrl:String,tblName:String,incDay: String, logger: Logger,statdate:String = "statdate"): Unit = {

    val delSql = s"delete from $tblName where $statdate='$incDay'"
    logger.error(">>>Before saving, delete today's rows: " + delSql)
    Class.forName("com.mysql.jdbc.Driver")
    val conn = DriverManager.getConnection(jdbcUrl, user, password)
    DbUtils.executeSql(conn, delSql)
    //conn.close()

    // Build a temporary DataFrame from the RDD
    val tmpTbl = spark.sqlContext.createDataFrame(rdd, schema).persist()

    // JDBC connection properties
    val prop = new Properties()
    prop.setProperty("user", user)
    prop.setProperty("password", password)
    prop.setProperty("charset", "utf8mb4")

    logger.error(s"Writing to MySQL table $tblName ...")

    // Append the data to the target table
    tmpTbl.write.mode(SaveMode.Append).jdbc(jdbcUrl, tblName, prop)

    DbUtils.querySql(conn, s"select count(1) from $tblName where $statdate='$incDay'")
    conn.close()
    tmpTbl.unpersist()
  }


  def df2MysqlInc(spark: SparkSession,rdd: RDD[Row],schema:StructType,user:String,password:String,
               saveMode:String,jdbcUrl:String,tblName:String,incDay: String, logger: Logger,statdate:String = "inc_day"): Unit = {

    val delSql = s"delete from $tblName where $statdate='$incDay'"
    logger.error(">>>Before saving, delete today's rows: " + delSql)
    Class.forName("com.mysql.jdbc.Driver")


    val conn = DriverManager.getConnection(jdbcUrl, user, password)
    DbUtils.executeSql(conn, delSql)
    DbUtils.executeSql(conn, "Set Names 'utf8mb4'")
    DbUtils.executeSql(conn, "set character_set_client = utf8mb4")

    //conn.close()

    // Build a temporary DataFrame from the RDD
    val tmpTbl = spark.sqlContext.createDataFrame(rdd, schema).persist()

    // JDBC connection properties
    val prop = new Properties()
    prop.setProperty("user", user)
    prop.setProperty("password", password)
    prop.setProperty("charset", "utf8mb4")

    logger.error(s"Writing to MySQL table $tblName ...")

    // Append the data to the target table
    tmpTbl.write.mode(SaveMode.Append).jdbc(jdbcUrl, tblName, prop)

    DbUtils.querySql(conn, s"select count(1) from $tblName where $statdate='$incDay'")
    logger.error("Row count written: " + tmpTbl.count())

    conn.close()
    tmpTbl.unpersist()
    logger.error("Wrote " + tblName + " successfully")


  }

  def df2MysqlIncDf(spark: SparkSession,df:DataFrame,schema:StructType,user:String,password:String,
                  saveMode:String,jdbcUrl:String,tblName:String,incDay: String, logger: Logger,statdate:String = "inc_day"): Unit = {

    val delSql = s"delete from $tblName where $statdate='$incDay'"
    logger.error(">>>Before saving, delete today's rows: " + delSql)
    Class.forName("com.mysql.jdbc.Driver")


    val conn = DriverManager.getConnection(jdbcUrl, user, password)
    DbUtils.executeSql(conn, delSql)
    DbUtils.executeSql(conn, "Set Names 'utf8mb4'")
    DbUtils.executeSql(conn, "set character_set_client = utf8mb4")

    //conn.close()

    // Cache the input DataFrame
    val tmpTbl = df.persist()

    // JDBC connection properties
    val prop = new Properties()
    prop.setProperty("user", user)
    prop.setProperty("password", password)
    prop.setProperty("charset", "utf8mb4")

    logger.error(s"Writing to MySQL table $tblName ...")

    // Append the data to the target table
    tmpTbl.write.mode(SaveMode.Append).jdbc(jdbcUrl, tblName, prop)

    DbUtils.querySql(conn, s"select count(1) from $tblName where $statdate='$incDay'")

    logger.error("Row count written: " + tmpTbl.count())

    conn.close()
    tmpTbl.unpersist()
    logger.error("Wrote " + tblName + " successfully")


  }

  /**
    * Two-step aggregation with salted keys, to spread skewed keys across partitions
    *
    * @param obj     input data
    * @param hashNum salt range; a random value in [0, hashNum) is prepended to each key for the first aggregation
    */
  def groupByKeyTwoStep(obj: RDD[(String, Object)], hashNum: Int): RDD[(String, ArrayBuffer[Object])] = {
    // First aggregation on the salted key
    val hashData = obj.map(obj => {
      val hashPrefix = new Random().nextInt(hashNum)
      ((hashPrefix, obj._1), obj._2)
    }).groupByKey().map(obj => {
      (obj._1._2, obj._2.toArray)
    })
    // Second aggregation on the original key, after dropping the salt
    hashData.groupByKey().map(obj => {
      val key = obj._1
      val valueIterator = obj._2.iterator
      val ret = new ArrayBuffer[Object]
      while (valueIterator.hasNext) {
        val tmpArray = valueIterator.next()
        ret.appendAll(tmpArray)
      }
      (key, ret)
    })
  }



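  /**
    * Apply fun to every element of the RDD while limiting the overall call rate to roughly
    * limitMin calls per minute, split evenly across partitionCnt partitions (with a 10% safety margin).
    * When a partition exceeds its per-minute quota it sleeps until the next minute starts.
    */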
  def akLimitMultiThreadRdd[T](body: => RDD[T])(fun: T => T)(limitMin: Int,partitionCnt: Int) ={

    val resRdd = body.mapPartitions(iter => {
      val partitionLimitMinu = limitMin * 0.9 / partitionCnt
      val lastMin = new AtomicInteger(Calendar.getInstance().get(Calendar.MINUTE))
      val timeInt = new AtomicInteger(0)
      val partitionsCount = new AtomicInteger(0)
      for (obj <- iter) yield {

        if (partitionsCount.incrementAndGet() % 10000 == 0) {
          logger.error(partitionsCount)
        }
        val second = Calendar.getInstance().get(Calendar.SECOND)
        val cur = Calendar.getInstance().get(Calendar.MINUTE)
        if (cur == lastMin.get()) {
          if (timeInt.incrementAndGet() >= partitionLimitMinu) {
            logger.error("minute: " + cur + ", calls this minute: " + timeInt + ", total: " + partitionsCount.get())
            Thread.sleep((60 - second) * 1000)
          }
        } else {
          // Not precise: set() is subject to races between threads
          timeInt.set(1)
          lastMin.set(cur)
        }
        fun(obj)
      }
    })

    resRdd

  }



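  /**
    * Single-record variant of the rate limiter used from multi-threaded callers: applies
    * fun(ak, obj, keyMap), appends the result to retList, and sleeps until the next minute
    * once the per-minute quota (partitionLimitMinu) has been reached.
    */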
  def akLimitMultiThreadDetail(partitionLimitMinu: Int, partitionsCount: AtomicInteger, lastMin: AtomicInteger,
                               timeInt: AtomicInteger, fun: (String, JSONObject, Map[String, String]) => JSONObject,
                               retList: LinkedBlockingQueue[JSONObject], ak: String,
                               obj: JSONObject, keyMap: Map[String, String]): Unit = {
    if (partitionsCount.incrementAndGet() % 100 == 0) {
      logger.error(partitionsCount)
    }
    val second = Calendar.getInstance().get(Calendar.SECOND)
    val cur = Calendar.getInstance().get(Calendar.MINUTE)
    if (cur == lastMin.get()) {
      if (timeInt.incrementAndGet() >= partitionLimitMinu) {
        logger.error("second: " + second + ", calls this minute: " + timeInt + ", total: " + partitionsCount.get())
        Thread.sleep((60 - second) * 1000)
      }
    } else {
      // Not precise: set() is subject to races between threads
      timeInt.set(1)
      lastMin.set(cur)
    }
    retList.add(fun(ak, obj, keyMap))
  }



//  def getRowToJson(spark:SparkSession,querySql:String,parNum:Int=200 ) ={
//    val sourDf = spark.sql(querySql).persist(StorageLevel.DISK_ONLY)
//
//    val colList = sourDf.columns
//
//    val sourRdd = sourDf.rdd.repartition(parNum).map( obj => {
//      val jsonObj = new JSONObject()
//      for (columns <- colList) {
//        jsonObj.put(columns,obj.getAs[String](columns))
//      }
//      jsonObj
//    }).persist(StorageLevel.DISK_ONLY)
//
//    println(s"共获取数据:${sourRdd.count()}")
//
//    sourDf.unpersist()
//
//    sourRdd
//  }

//
//  def getRowToJsonTurp( spark:SparkSession,querySql:String,key:String,parNum:Int=200 ) ={
//    val sourDf = spark.sql(querySql).persist(StorageLevel.DISK_ONLY)
//
//    val colList = sourDf.columns
//
//    val sourRdd = sourDf.rdd.repartition(parNum).map( obj => {
//      val jsonObj = new JSONObject()
//      for (columns <- colList) {
//        jsonObj.put(columns,obj.getAs[String](columns))
//      }
//      (jsonObj.getString(key),jsonObj)
//    }).persist(StorageLevel.DISK_ONLY)
//
//    sourRdd
//  }
}
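
A minimal usage sketch of how these helpers can be wired together in a job. The session settings, table names, endpoint URL, and columns below are illustrative assumptions, not part of the original class: read a source table, convert rows to JSON, call a POST endpoint per record with the built-in retry, and write the results back into a Hive partition.

// Illustrative driver; table names, URL and columns are placeholders.
import org.apache.log4j.Logger
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}

object SparkUtilsExample {
  def main(args: Array[String]): Unit = {
    val logger = Logger.getLogger(this.getClass)
    val spark = SparkSession.builder().appName("spark-utils-example").enableHiveSupport().getOrCreate()
    val incDay = args(0) // e.g. "20240101"

    // 1. Load the source data and turn each Row into a JSONObject (all columns as String).
    val sourceDf = spark.sql(s"select id, addr from tmp_db.tmp_src where inc_day='$incDay'")
    val jsonRdd = SparkUtils.getRowToJson(sourceDf, "DISK_ONLY", 100)

    // 2. Call an HTTP endpoint for every record, using doPost's single retry on failure.
    //    A per-partition logger avoids capturing a non-serializable Logger in the closure.
    val resultRdd = jsonRdd.mapPartitions { iter =>
      val taskLogger = Logger.getLogger("spark-utils-example")
      iter.map { obj =>
        val resp = SparkUtils.doPost("http://example.com/api/geocode", obj, taskLogger)
        Row(obj.getString("id"), resp, incDay)
      }
    }

    // 3. Write the results into a partitioned Hive table; df2Hive drops the partition first.
    val schema = StructType(Seq(
      StructField("id", StringType),
      StructField("response", StringType),
      StructField("inc_day", StringType)))
    SparkUtils.df2Hive(spark, resultRdd, schema, "append", "tmp_db.tmp_result", "inc_day", incDay, logger)

    spark.stop()
  }
}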