Spark write utility class

package com.sf.gis.scala.base.spark

import java.util.Properties

import com.alibaba.fastjson.JSONObject
import com.sf.gis.java.base.dto.DBInfo
import com.sf.gis.java.base.pojo.BasePojo
import com.sf.gis.java.base.util.{CalPartitionUtil, ObjectUtil}
import com.sf.gis.scala.base.util.JSONUtil
import org.apache.log4j.Logger
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SaveMode, SparkSession}
import org.apache.spark.storage.StorageLevel

import scala.collection.mutable.ArrayBuffer


/**
  * Created by 01374443 on 2020/7/27.
  */
object SparkWrite {
  @transient lazy val logger: Logger = Logger.getLogger(this.getClass)


  /**
    * Write to a Hive table with static partitions.
    *
    * @param sparkSession
    * @param dataRdd          data to be saved
    * @param saveKeyName      JSON keys to extract, in the column order of the target table
    * @param tableName        target Hive table (dbName.tableName)
    * @param partitionData    static partition data, as (partition column name, partition value)
    * @param saveMode         save mode: append to or overwrite the target partition
    * @param partitionCalSize sizing hint passed to CalPartitionUtil.getPartitionSize when computing the number of write partitions
    * @return
    */
  def save2HiveStatic(sparkSession: SparkSession, dataRdd: RDD[JSONObject], saveKeyName: Array[String], tableName: String, partitionData: Array[(String, String)],
                      saveMode: SaveMode = SaveMode.Overwrite, partitionCalSize: Int = 1000): Unit = {
    val schemaEle = new ArrayBuffer[StructField]()
    for (i <- saveKeyName.indices) {
      schemaEle.append(StructField(saveKeyName.apply(i), StringType, nullable = true))
    }
    val schema = StructType(schemaEle.toArray)
    val rowRdd = dataRdd.map(obj => {
      val row = new ArrayBuffer[String]
      for (i <- saveKeyName.indices) {
        row.append(JSONUtil.getJsonValSingle(obj, saveKeyName.apply(i)).replaceAll("[\\r\\n\\t]", ""))
      }
      val ret = Row.fromSeq(row)
      ret
    })
    var rowDf = sparkSession.createDataFrame(rowRdd, schema).persist(StorageLevel.MEMORY_AND_DISK_SER_2)
    val savePartition = CalPartitionUtil.getPartitionSize(rowDf, partitionCalSize)
    logger.error("存储分区数量:" + savePartition)
    if (partitionData != null && !partitionData.isEmpty) {
      for (i <- partitionData.indices) {
        println(partitionData.apply(i)._1 + "," + partitionData.apply(i)._2)
        rowDf = rowDf.withColumn(partitionData.apply(i)._1, lit(partitionData.apply(i)._2))
      }
    }
    rowDf.repartition(savePartition).write.mode(saveMode).insertInto(tableName)
    rowDf.unpersist()
    logger.error("存储完毕")
  }
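  // Usage sketch for save2HiveStatic (illustrative only): the database, table and partition
  // names below are assumptions; the target table must already exist with matching columns
  // and an inc_day partition column.
  def save2HiveStaticExample(sparkSession: SparkSession): Unit = {
    val record = new JSONObject()
    record.put("id", "1")
    record.put("name", "demo")
    val demoRdd: RDD[JSONObject] = sparkSession.sparkContext.parallelize(Seq(record))
    save2HiveStatic(sparkSession, demoRdd,
      saveKeyName = Array("id", "name"),
      tableName = "demo_db.demo_table",
      partitionData = Array(("inc_day", "20200727")))
  }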

  /**
    * Single-column write: each JSON object is serialized as a whole into a single "total" column.
    *
    * @param sparkSession
    * @param dataRdd       data to be saved
    * @param tableName     target Hive table (dbName.tableName)
    * @param partitionData static partition data, as (partition column name, partition value)
    * @param saveMode      save mode: append to or overwrite the target partition
    */
  def save2HiveStaticSingleColumn(sparkSession: SparkSession, dataRdd: RDD[JSONObject], tableName: String, partitionData: Array[(String, String)],
                                  saveMode: SaveMode = SaveMode.Overwrite, partitionCalSize: Int = 1000): Unit = {
    val schemaEle = new ArrayBuffer[StructField]()
    schemaEle.append(StructField("total", StringType, nullable = true))

    val schema = StructType(schemaEle.toArray)
    val rowRdd = dataRdd.map(obj => {
      val row = new ArrayBuffer[String]
      row.append(obj.toJSONString.replaceAll("[\\r\\n\\t]", ""))
      val ret = Row.fromSeq(row)
      ret
    })
    var rowDf = sparkSession.createDataFrame(rowRdd, schema).persist(StorageLevel.MEMORY_AND_DISK_SER_2)
    val savePartition = CalPartitionUtil.getPartitionSize(rowDf, partitionCalSize)
    logger.error("存储分区数量:" + savePartition)
    if (partitionData != null && !partitionData.isEmpty) {
      for (i <- partitionData.indices) {
        println(partitionData.apply(i)._1 + "," + partitionData.apply(i)._2)
        rowDf = rowDf.withColumn(partitionData.apply(i)._1, lit(partitionData.apply(i)._2))
      }
    }
    rowDf.repartition(savePartition).write.mode(saveMode).insertInto(tableName)
    rowDf.unpersist()
    logger.error("存储完毕")
  }
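  // Usage sketch for save2HiveStaticSingleColumn (illustrative only): each JSON object is written
  // as one string into the "total" column of the assumed table demo_db.demo_json_table.
  def save2HiveStaticSingleColumnExample(sparkSession: SparkSession, dataRdd: RDD[JSONObject]): Unit = {
    save2HiveStaticSingleColumn(sparkSession, dataRdd,
      tableName = "demo_db.demo_json_table",
      partitionData = Array(("inc_day", "20200727")))
  }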

  // Convert a tuple to an Array
  def tupleToList(p: Product): Array[Any] = p.productIterator.toArray

  /**
    * Write an RDD of arrays to a Hive table with static partitions.
    *
    * @param sparkSession
    * @param dataRdd       data to be saved; each Array[Any] becomes one row, written positionally
    * @param tableName     target Hive table (dbName.tableName)
    * @param partitionData static partition data, as (partition column name, partition value)
    * @param saveMode      save mode: append to or overwrite the target partition
    * @return
    */
  def saveArray2HiveStatic(sparkSession: SparkSession, dataRdd: RDD[Array[Any]], tableName: String, partitionData: Array[(String, String)],
                           saveMode: SaveMode = SaveMode.Overwrite, partitionCalSize: Int = 1000): Unit = {
    if (dataRdd.isEmpty()) {
      logger.error("No data to save, returning early")
      return
    }
    val firstSize = dataRdd.take(1).apply(0).length
    val schemaEle = new ArrayBuffer[StructField]()
    for (i <- 0 until firstSize) {
      schemaEle.append(StructField(i + "", StringType, nullable = true))
    }
    val schema = StructType(schemaEle.toArray)
    val rowRdd = dataRdd.map(obj => {
      val row = new ArrayBuffer[String]
      for (i <- obj.indices) {
        row.append(obj(i).toString.replaceAll("[\\r\\n\\t]", ""))
      }
      val ret = Row.fromSeq(row)
      ret
    })
    var rowDf = sparkSession.createDataFrame(rowRdd, schema).persist(StorageLevel.MEMORY_AND_DISK_SER_2)
    val savePartition = CalPartitionUtil.getPartitionSize(rowDf, partitionCalSize)
    logger.error("Number of write partitions: " + savePartition)
    if (partitionData != null && !partitionData.isEmpty) {
      for (i <- partitionData.indices) {
        println(partitionData.apply(i)._1 + "," + partitionData.apply(i)._2)
        rowDf = rowDf.withColumn(partitionData.apply(i)._1, lit(partitionData.apply(i)._2))
      }
    }
    rowDf.repartition(savePartition).write.mode(saveMode).insertInto(tableName)
    rowDf.unpersist()
    logger.error("存储完毕")
  }
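  // Usage sketch for saveArray2HiveStatic (illustrative only): each Array[Any] becomes one row,
  // written positionally into the assumed table demo_db.demo_array_table.
  def saveArray2HiveStaticExample(sparkSession: SparkSession): Unit = {
    val rows: RDD[Array[Any]] = sparkSession.sparkContext.parallelize(Seq(
      Array[Any]("1", 22.54, 114.05),
      Array[Any]("2", 22.60, 114.10)))
    saveArray2HiveStatic(sparkSession, rows,
      tableName = "demo_db.demo_array_table",
      partitionData = Array(("inc_day", "20200727")))
  }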

  /**
    * Write POJOs to a Hive table with static partitions, using the field order defined by the sort annotation.
    *
    * @param sparkSession
    * @param dataRdd       data to be saved; the columns to save are the POJO fields carrying the sort annotation
    * @param tableName     target Hive table (dbName.tableName)
    * @param partitionData static partition data, as (partition column name, partition value)
    * @param saveMode      save mode: append to or overwrite the target partition
    * @return
    */
  def savePojo2HiveStaticWithSort(sparkSession: SparkSession, dataRdd: RDD[BasePojo], tableName: String, partitionData: Array[(String, String)],
                                  saveMode: SaveMode, partitionCalSize: Int = 1000): Unit = {
    val dataList = dataRdd.take(1)
    if (dataList.length == 0) {
      logger.error("数据量为空,直接返回")
      return
    }
    val dataItem = dataList(0)
    val saveColumnsIndex = ObjectUtil.getAnnotationColumnBySort(dataItem.getClass.getDeclaredFields)
    val saveColumns = ObjectUtil.getColumnsNameByFieldIndex(dataItem.getClass.getDeclaredFields, saveColumnsIndex)
    logger.error("存储的列名:" + saveColumns.mkString(","))
    val schemaEle = new ArrayBuffer[StructField]()
    for (i <- saveColumns.indices) {
      schemaEle.append(StructField(saveColumns(i), StringType, nullable = true))
    }
    val schema = StructType(schemaEle.toArray)
    val rowRdd = dataRdd.map(obj => {
      val dataArray = ObjectUtil.toArrayByColumns(obj, saveColumnsIndex)
      val ret = Row.fromSeq(dataArray)
      ret
    })
    var rowDf = sparkSession.createDataFrame(rowRdd, schema).persist(StorageLevel.MEMORY_AND_DISK_SER_2)
    val savePartition = CalPartitionUtil.getPartitionSize(rowDf, partitionCalSize)
    logger.error("Number of write partitions: " + savePartition)
    if (partitionData != null && !partitionData.isEmpty) {
      for (i <- partitionData.indices) {
        println(partitionData.apply(i)._1 + "," + partitionData.apply(i)._2)
        rowDf = rowDf.withColumn(partitionData.apply(i)._1, lit(partitionData.apply(i)._2))
      }
    }
    rowDf.repartition(savePartition).write.mode(saveMode).insertInto(tableName)
    rowDf.unpersist()
    logger.error("存储完毕")
  }
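  // Usage sketch for savePojo2HiveStaticWithSort (illustrative only): assumes an existing RDD of
  // BasePojo subclasses whose fields carry the project's sort annotation; the table and partition
  // names are placeholders.
  def savePojo2HiveStaticWithSortExample(sparkSession: SparkSession, pojoRdd: RDD[BasePojo]): Unit = {
    savePojo2HiveStaticWithSort(sparkSession, pojoRdd,
      tableName = "demo_db.demo_pojo_table",
      partitionData = Array(("inc_day", "20200727")),
      saveMode = SaveMode.Overwrite)
  }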

  /**
    * Write to a Hive table with dynamic partitions.
    *
    * @param sparkSession
    * @param dataRdd          data to be saved
    * @param saveKeyName      JSON keys to extract for the non-partition columns, in the column order of the target table
    * @param tableName        target Hive table (dbName.tableName)
    * @param saveMode         save mode: append or overwrite
    * @param partitionCalSize sizing hint passed to CalPartitionUtil.getPartitionSize when computing the number of write partitions
    * @param partitionColumns dynamic partition column names (varargs); their values are read from the data, so each row lands in its matching partition
    * @return
    */
  def save2HiveDynamic(sparkSession: SparkSession, dataRdd: RDD[JSONObject], saveKeyName: Array[String], tableName: String, saveMode: SaveMode, partitionCalSize: Int, partitionColumns: String*): Unit = {
    // Dynamic partition columns must be the trailing columns of the DataFrame, in the same order
    // as in the target table, so they are appended after the data columns; their values are read
    // from the JSON data as well.
    val allColumns = saveKeyName ++ partitionColumns
    val schema = StructType(allColumns.map(name => StructField(name, StringType, nullable = true)))
    val rowRdd = dataRdd.map(obj => {
      val row = allColumns.map(key => JSONUtil.getJsonValSingle(obj, key).replaceAll("[\\r\\n\\t]", ""))
      Row.fromSeq(row)
    })
    val rowDf = sparkSession.createDataFrame(rowRdd, schema)
    val savePartition = CalPartitionUtil.getPartitionSize(rowDf, partitionCalSize)
    logger.error("Number of write partitions: " + savePartition)
    // insertInto() cannot be combined with partitionBy(); with dynamic partitioning enabled on
    // the session, the partition values are taken from the trailing columns.
    rowDf.repartition(savePartition).write.mode(saveMode).insertInto(tableName)
    logger.error("Write finished")
  }
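  // Note (assumption): inserting into a partitioned table without static partition values normally
  // requires Hive dynamic partitioning to be enabled on the session, e.g. before calling
  // save2HiveDynamic:
  def enableDynamicPartition(sparkSession: SparkSession): Unit = {
    sparkSession.sql("set hive.exec.dynamic.partition=true")
    sparkSession.sql("set hive.exec.dynamic.partition.mode=nonstrict")
  }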

  //  /**
  //    * Write to a Hive table with dynamic partitions, using the order in which the fields are declared.
  //    *
  //    * @param sparkSession
  //    * @param dataRdd          data to be saved; the columns to save are marked by the sort annotation on the object
  //    * @param tableName        target Hive table (dbName.tableName)
  //    * @param saveMode         save mode: append to or overwrite the target partition
  //    * @return
  //    */

  /**
    * TODO: not fully debugged yet.
    *
    * @param sparkSession
    * @param dbInfo           database connection info (JDBC url, user, password)
    * @param tableName        target MySQL table name
    * @param dataRdd          data to be saved
    * @param saveColumns      JSON keys to extract; must match the target table's column names
    * @param savePartitionNum number of partitions, i.e. the JDBC write parallelism
    */
  def save2Mysql(sparkSession: SparkSession, dbInfo: DBInfo, tableName: String, dataRdd: RDD[JSONObject],
                 saveColumns: Array[String], savePartitionNum: Int = 10): Unit = {
    val connectionProperties = new Properties()
    connectionProperties.setProperty("user", dbInfo.getUser) // set the user name
    connectionProperties.setProperty("password", dbInfo.getPassword) // set the password
    // Build a DataFrame whose column names match the target table, so the JDBC writer maps them correctly
    val schema = StructType(saveColumns.map(name => StructField(name, StringType, nullable = true)))
    val rowRdd = dataRdd.map(obj => {
      val row = saveColumns.map(key => JSONUtil.getJsonValSingle(obj, key).replaceAll("[\\r\\n\\t]", ""))
      Row.fromSeq(row)
    })
    var rowDf = sparkSession.createDataFrame(rowRdd, schema)
    if (rowDf.rdd.getNumPartitions != savePartitionNum) {
      rowDf = rowDf.repartition(savePartitionNum)
    }
    logger.error("JDBC write parallelism: " + savePartitionNum)
    rowDf.write.mode(SaveMode.Append).jdbc(dbInfo.getUrl, tableName, connectionProperties)
  }
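  // Usage sketch for save2Mysql (illustrative only): assumes a DBInfo already populated with the
  // JDBC url, user and password; the table and column names are placeholders and must exist in MySQL.
  def save2MysqlExample(sparkSession: SparkSession, dbInfo: DBInfo, dataRdd: RDD[JSONObject]): Unit = {
    save2Mysql(sparkSession, dbInfo, "demo_table", dataRdd, Array("id", "name"))
  }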

  def main(args: Array[String]): Unit = {
    val sparkSession = Spark.getSparkSession(this.getClass.getSimpleName.replace("$", ""), null, isLocal = true)
    val dbInfo = new DBInfo("")
  }
}