Spark write utility class

package com.sf.gis.scala.base.spark

import java.util.Properties

import com.alibaba.fastjson.JSONObject
import com.sf.gis.java.base.dto.DBInfo
import com.sf.gis.java.base.pojo.BasePojo
import com.sf.gis.java.base.util.{CalPartitionUtil, ObjectUtil}
import com.sf.gis.scala.base.util.JSONUtil
import org.apache.log4j.Logger
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SaveMode, SparkSession}
import org.apache.spark.storage.StorageLevel

import scala.collection.mutable.ArrayBuffer

/**
 * Created by 01374443 on 2020/7/27.
 */
object SparkWrite {

  @transient lazy val logger: Logger = Logger.getLogger(this.getClass)

  /**
   * Save an RDD of JSON objects to a Hive table with a static partition.
   *
   * @param sparkSession
   * @param dataRdd       data to save
   * @param saveKeyName   JSON keys to extract, in the target table's column order
   * @param tableName     target Hive table (db.table)
   * @param partitionData static partition spec as (partition column, partition value) pairs
   * @param saveMode      save mode: append to, or overwrite, the target partition
   * @return
   */
  def save2HiveStatic(sparkSession: SparkSession, dataRdd: RDD[JSONObject], saveKeyName: Array[String], tableName: String,
                      partitionData: Array[(String, String)], saveMode: SaveMode = SaveMode.Overwrite,
                      partitionCalSize: Int = 1000): Unit = {
    // Build a string schema from the JSON keys to save.
    val schemaEle = new ArrayBuffer[StructField]()
    for (i <- saveKeyName.indices) {
      schemaEle.append(StructField(saveKeyName(i), StringType, nullable = true))
    }
    val schema = StructType(schemaEle.toArray)
    // Extract the value for each key and strip characters that would break the row layout.
    val rowRdd = dataRdd.map(obj => {
      val row = new ArrayBuffer[String]
      for (i <- saveKeyName.indices) {
        row.append(JSONUtil.getJsonValSingle(obj, saveKeyName(i)).replaceAll("[\\r\\n\\t]", ""))
      }
      Row.fromSeq(row)
    })
    var rowDf = sparkSession.createDataFrame(rowRdd, schema).persist(StorageLevel.MEMORY_AND_DISK_SER_2)
    val savePartition = CalPartitionUtil.getPartitionSize(rowDf, partitionCalSize)
    logger.error("number of write partitions: " + savePartition)
    // Append the static partition columns as constant columns; insertInto matches columns by
    // position, so they must come last, as Hive expects for partition columns.
    if (partitionData != null && !partitionData.isEmpty) {
      for (i <- partitionData.indices) {
        println(partitionData(i)._1 + "," + partitionData(i)._2)
        rowDf = rowDf.withColumn(partitionData(i)._1, lit(partitionData(i)._2))
      }
    }
    rowDf.repartition(savePartition).write.mode(saveMode).insertInto(tableName)
    rowDf.unpersist()
    logger.error("write finished")
  }

  /**
   * Save each record as a single column.
   *
   * @param sparkSession
   * @param dataRdd
   * @param tableName
   * @param partitionData
   * @param saveMode
   */
  def save2HiveStaticSingleColumn(sparkSession: SparkSession, dataRdd: RDD[JSONObject], tableName: String,
                                  partitionData: Array[(String, String)], saveMode: SaveMode = SaveMode.Overwrite,
                                  partitionCalSize: Int = 1000): Unit = {
    // Single-column schema: each record is stored as its full JSON string.
    val schemaEle = new ArrayBuffer[StructField]()
    schemaEle.append(StructField("total", StringType, nullable = true))
    val schema = StructType(schemaEle.toArray)
    val rowRdd = dataRdd.map(obj => {
      val row = new ArrayBuffer[String]
      row.append(obj.toJSONString.replaceAll("[\\r\\n\\t]", ""))
      Row.fromSeq(row)
    })
    var rowDf = sparkSession.createDataFrame(rowRdd, schema).persist(StorageLevel.MEMORY_AND_DISK_SER_2)
    val savePartition = CalPartitionUtil.getPartitionSize(rowDf, partitionCalSize)
    logger.error("number of write partitions: " + savePartition)
    if (partitionData != null && !partitionData.isEmpty) {
      for (i <- partitionData.indices) {
        println(partitionData(i)._1 + "," + partitionData(i)._2)
        rowDf = rowDf.withColumn(partitionData(i)._1, lit(partitionData(i)._2))
      }
    }
    rowDf.repartition(savePartition).write.mode(saveMode).insertInto(tableName)
    rowDf.unpersist()
    logger.error("write finished")
  }

  // Convert a tuple to an array
  def tupleToList(p: Product) = p.productIterator.toArray
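  // Minimal usage sketch for save2HiveStatic (not part of the original post; the table,
  // column and partition names are hypothetical). It assumes a Hive table with string
  // data columns `id` and `name` plus a partition column `inc_day`:
  //
  //   val data: RDD[JSONObject] = ...
  //   SparkWrite.save2HiveStatic(
  //     sparkSession,
  //     data,
  //     saveKeyName = Array("id", "name"),
  //     tableName = "demo_db.demo_table",
  //     partitionData = Array(("inc_day", "20200727")),
  //     saveMode = SaveMode.Overwrite)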
  /**
   * Save an RDD of arrays to a Hive table with a static partition.
   *
   * @param sparkSession
   * @param dataRdd       data to save
   * @param tableName     target Hive table (db.table)
   * @param partitionData static partition spec as (partition column, partition value) pairs
   * @param saveMode      save mode: append to, or overwrite, the target partition
   * @return
   */
  def saveArray2HiveStatic(sparkSession: SparkSession, dataRdd: RDD[Array[Any]], tableName: String,
                           partitionData: Array[(String, String)], saveMode: SaveMode = SaveMode.Overwrite,
                           partitionCalSize: Int = 1000): Unit = {
    if (dataRdd.isEmpty()) {
      logger.error("no data to save, returning")
      return
    }
    // Derive the column count from the first record; columns are named positionally ("0", "1", ...)
    // because insertInto matches columns by position, not by name.
    val firstSize = dataRdd.take(1)(0).length
    val schemaEle = new ArrayBuffer[StructField]()
    for (i <- 0 until firstSize) {
      schemaEle.append(StructField(i + "", StringType, nullable = true))
    }
    val schema = StructType(schemaEle.toArray)
    val rowRdd = dataRdd.map(obj => {
      val row = new ArrayBuffer[String]
      for (i <- obj.indices) {
        row.append(String.valueOf(obj(i)).replaceAll("[\\r\\n\\t]", ""))
      }
      Row.fromSeq(row)
    })
    var rowDf = sparkSession.createDataFrame(rowRdd, schema).persist(StorageLevel.MEMORY_AND_DISK_SER_2)
    val savePartition = CalPartitionUtil.getPartitionSize(rowDf, partitionCalSize)
    logger.error("number of write partitions: " + savePartition)
    if (partitionData != null && !partitionData.isEmpty) {
      for (i <- partitionData.indices) {
        println(partitionData(i)._1 + "," + partitionData(i)._2)
        rowDf = rowDf.withColumn(partitionData(i)._1, lit(partitionData(i)._2))
      }
    }
    rowDf.repartition(savePartition).write.mode(saveMode).insertInto(tableName)
    rowDf.unpersist()
    logger.error("write finished")
  }
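  // Minimal usage sketch for saveArray2HiveStatic (not part of the original post; names are
  // hypothetical). Each inner array must hold the data columns in the target table's column
  // order, since insertInto matches columns by position:
  //
  //   val rows: RDD[Array[Any]] = sparkSession.sparkContext.parallelize(
  //     Seq(Array[Any]("1001", "foo"), Array[Any]("1002", "bar")))
  //   SparkWrite.saveArray2HiveStatic(
  //     sparkSession,
  //     rows,
  //     tableName = "demo_db.demo_array_table",
  //     partitionData = Array(("inc_day", "20200727")))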
  /**
   * Save an RDD of POJOs to a Hive table with a static partition, ordering columns by the
   * sort annotation on the POJO fields.
   *
   * @param sparkSession
   * @param dataRdd   data to save; the columns to save are the fields carrying the sort annotation
   * @param tableName target Hive table (db.table)
   * @param saveMode  save mode: append to, or overwrite, the target partition
   * @return
   */
  def savePojo2HiveStaticWithSort(sparkSession: SparkSession, dataRdd: RDD[BasePojo], tableName: String,
                                  partitionData: Array[(String, String)], saveMode: SaveMode,
                                  partitionCalSize: Int = 1000): Unit = {
    val dataList = dataRdd.take(1)
    if (dataList.isEmpty) {
      logger.error("no data to save, returning")
      return
    }
    val dataItem = dataList(0)
    // Resolve the annotated fields of the POJO and the corresponding column names.
    val saveColumnsIndex = ObjectUtil.getAnnotationColumnBySort(dataItem.getClass.getDeclaredFields)
    val saveColumns = ObjectUtil.getColumnsNameByFieldIndex(dataItem.getClass.getDeclaredFields, saveColumnsIndex)
    logger.error("columns to save: " + saveColumns.mkString(","))
    val schemaEle = new ArrayBuffer[StructField]()
    if (saveColumns != null) {
      for (i <- saveColumns.indices) {
        schemaEle.append(StructField(saveColumns(i), StringType, nullable = true))
      }
    }
    val schema = StructType(schemaEle.toArray)
    val rowRdd = dataRdd.map(obj => {
      val dataArray = ObjectUtil.toArrayByColumns(obj, saveColumnsIndex)
      Row.fromSeq(dataArray)
    })
    var rowDf = sparkSession.createDataFrame(rowRdd, schema).persist(StorageLevel.MEMORY_AND_DISK_SER_2)
    val savePartition = CalPartitionUtil.getPartitionSize(rowDf, partitionCalSize)
    logger.error("number of write partitions: " + savePartition)
    if (partitionData != null && !partitionData.isEmpty) {
      for (i <- partitionData.indices) {
        println(partitionData(i)._1 + "," + partitionData(i)._2)
        rowDf = rowDf.withColumn(partitionData(i)._1, lit(partitionData(i)._2))
      }
    }
    rowDf.repartition(savePartition).write.mode(saveMode).insertInto(tableName)
    rowDf.unpersist()
    logger.error("write finished")
  }

  /**
   * Save an RDD of JSON objects to a Hive table with dynamic partitions.
   *
   * @param sparkSession
   * @param dataRdd          data to save
   * @param saveKeyName      JSON keys to extract, in the target table's column order, with the
   *                         dynamic partition columns last
   * @param tableName        target Hive table (db.table)
   * @param saveMode         save mode: append to, or overwrite, the matched partitions
   * @param partitionColumns dynamic partition column names (varargs); each row is written to the
   *                         partition given by its values in these columns
   * @return
   */
  def save2HiveDynamic(sparkSession: SparkSession, dataRdd: RDD[JSONObject], saveKeyName: Array[String], tableName: String,
                       saveMode: SaveMode, partitionCalSize: Int, partitionColumns: String*): Unit = {
    // Build an explicit schema so every JSON key becomes its own column.
    val schemaEle = new ArrayBuffer[StructField]()
    for (i <- saveKeyName.indices) {
      schemaEle.append(StructField(saveKeyName(i), StringType, nullable = true))
    }
    val schema = StructType(schemaEle.toArray)
    val rowRdd = dataRdd.map(obj => {
      val row = new ArrayBuffer[String]
      for (i <- saveKeyName.indices) {
        row.append(JSONUtil.getJsonValSingle(obj, saveKeyName(i)).replaceAll("[\\r\\n\\t]", ""))
      }
      Row.fromSeq(row)
    })
    val rowDf = sparkSession.createDataFrame(rowRdd, schema)
    val savePartition = CalPartitionUtil.getPartitionSize(rowDf, partitionCalSize)
    logger.error("number of write partitions: " + savePartition)
    // insertInto cannot be combined with partitionBy(); the columns listed in partitionColumns
    // must simply be the trailing entries of saveKeyName, and the session must allow dynamic
    // partition inserts (hive.exec.dynamic.partition.mode=nonstrict).
    rowDf.repartition(savePartition).write.mode(saveMode).insertInto(tableName)
    logger.error("write finished")
  }

  // TODO: add a variant that saves to Hive with dynamic partitions following the field
  // declaration order of the POJO (not yet implemented).

  /**
   * TODO not yet verified end to end
   *
   * @param sparkSession
   * @param dbInfo
   * @param tableName
   * @param dataRdd
   */
  def save2Mysql(sparkSession: SparkSession, dbInfo: DBInfo, tableName: String, dataRdd: RDD[JSONObject],
                 saveColumns: Array[String], savePartitionNum: Int = 10): Unit = {
    val connectionProperties = new Properties()
    connectionProperties.setProperty("user", dbInfo.getUser)         // database user
    connectionProperties.setProperty("password", dbInfo.getPassword) // database password
    // Build an explicit schema so every saved key maps to a column of the target table.
    val schemaEle = new ArrayBuffer[StructField]()
    for (i <- saveColumns.indices) {
      schemaEle.append(StructField(saveColumns(i), StringType, nullable = true))
    }
    val schema = StructType(schemaEle.toArray)
    val rowRdd = dataRdd.map(obj => {
      val row = new ArrayBuffer[String]
      for (i <- saveColumns.indices) {
        row.append(JSONUtil.getJsonValSingle(obj, saveColumns(i)).replaceAll("[\\r\\n\\t]", ""))
      }
      Row.fromSeq(row)
    })
    var rowDf = sparkSession.createDataFrame(rowRdd, schema)
    if (dataRdd.getNumPartitions != savePartitionNum) {
      rowDf = rowDf.repartition(savePartitionNum)
    }
    logger.error("JDBC write parallelism: " + savePartitionNum)
    rowDf.write.mode(SaveMode.Append).jdbc(dbInfo.getUrl, tableName, connectionProperties)
  }

  def main(args: Array[String]): Unit = {
    val sparkSession = Spark.getSparkSession(this.getClass.getSimpleName.replace("$", ""), null, isLocal = true)
    val dbInfo = new DBInfo("")
  }
}
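save2HiveDynamic relies on Hive's dynamic-partition insert, so the session has to allow it. Below is a minimal usage sketch (not from the original post; the table and column names are hypothetical), assuming a Hive-enabled SparkSession and a table demo_db.demo_dyn_table with data columns id, name and a partition column inc_day:

import com.alibaba.fastjson.JSONObject
import com.sf.gis.scala.base.spark.SparkWrite
import org.apache.spark.sql.{SaveMode, SparkSession}

object SparkWriteDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("SparkWriteDemo")
      .enableHiveSupport()
      .getOrCreate()

    // insertInto-based partition writes go through Hive's dynamic-partition path,
    // so these settings are required before calling save2HiveDynamic.
    spark.sql("SET hive.exec.dynamic.partition=true")
    spark.sql("SET hive.exec.dynamic.partition.mode=nonstrict")

    // One sample record; the partition value travels inside the data itself.
    val obj = new JSONObject()
    obj.put("id", "1001")
    obj.put("name", "foo")
    obj.put("inc_day", "20200727")
    val data = spark.sparkContext.parallelize(Seq(obj))

    SparkWrite.save2HiveDynamic(
      spark,
      data,
      Array("id", "name", "inc_day"), // saveKeyName: partition column last, matching the table
      "demo_db.demo_dyn_table",       // hypothetical target table
      SaveMode.Overwrite,
      1000,                           // partitionCalSize
      "inc_day")                      // dynamic partition column
  }
}

Since insertInto takes no explicit partition spec, the static-partition methods above write their constant partition values through the same dynamic-partition path, so the same session settings apply to them as well.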