Spark read utility class

package com.sf.gis.scala.base.spark

import org.apache.log4j.Logger
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.storage.StorageLevel
import com.alibaba.fastjson.JSONObject
import com.sf.gis.java.base.dto.DBInfo
/**
  * Created by 01374443 on 2020/7/27.
  */
object SparkRead {
  @transient lazy val logger: Logger = Logger.getLogger(this.getClass)

  /**
    * Run a Spark SQL query against Hive.
    *
    * @param sparkSession  active SparkSession
    * @param sql           SQL statement to execute
    * @param repartitions  target partition count; 0 derives it from executor instances * cores * 4
    * @return RDD of rows and the array of column names
    */
  def readHiveAsRow(sparkSession: SparkSession, sql: String, repartitions: Int = 0): (RDD[Row], Array[String]) = {
    val df = sparkSession.sql(sql)
    val columns = df.columns
    logger.error("Columns: " + columns.mkString(","))
    // Default partition count: executor instances * cores per executor * 4.
    var partition = repartitions
    if (partition == 0) {
      val instances = Integer.valueOf(sparkSession.sparkContext.getConf.get("spark.executor.instances", "0"))
      val cores = Integer.valueOf(sparkSession.sparkContext.getConf.get("spark.executor.cores", "0"))
      partition = instances * cores * 4
    }
    // Fall back to defaultParallelism when the executor configs are unset (e.g. dynamic allocation);
    // repartition(0) would otherwise throw.
    if (partition <= 0) partition = sparkSession.sparkContext.defaultParallelism
    val dataRdd = df.rdd.repartition(partition).persist(StorageLevel.MEMORY_AND_DISK_SER_2)
    logger.error("Total rows returned by the query: " + dataRdd.count())
    (dataRdd, columns)
  }
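
  // Usage sketch for readHiveAsRow (the table name is a placeholder): with
  // spark.executor.instances=10, spark.executor.cores=4 and repartitions left at 0,
  // the result is spread over 10 * 4 * 4 = 160 partitions.
  //   val (rows, cols) = SparkRead.readHiveAsRow(spark, "select waybill_no, src_city from dm_gis.order_info")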

  /**
    * Run a Spark SQL query against Hive and convert each row to JSON.
    *
    * @param sparkSession  active SparkSession
    * @param sql           SQL statement to execute
    * @param repartitions  target partition count; 0 derives it from executor instances * cores * 4
    * @return RDD of JSONObject and the array of column names
    */
  def readHiveAsJson(sparkSession: SparkSession, sql: String, repartitions: Int = 0): (RDD[JSONObject], Array[String]) = {
    val df = sparkSession.sql(sql)
    val columns = df.columns
    logger.error("Columns: " + columns.mkString(","))
    // Default partition count: executor instances * cores per executor * 4.
    var partition = repartitions
    if (partition == 0) {
      val instances = Integer.valueOf(sparkSession.sparkContext.getConf.get("spark.executor.instances", "0"))
      val cores = Integer.valueOf(sparkSession.sparkContext.getConf.get("spark.executor.cores", "0"))
      partition = instances * cores * 4
    }
    if (partition <= 0) partition = sparkSession.sparkContext.defaultParallelism
    val dataRdd = df.rdd.repartition(partition).map(row => {
      val jObj = new JSONObject()
      // getString assumes string-typed columns; cast non-string columns to string in the SQL if needed.
      for (i <- columns.indices) {
        jObj.put(columns(i), row.getString(i))
      }
      jObj
    }).persist(StorageLevel.MEMORY_AND_DISK_SER_2)
    logger.error("Total rows returned by the query: " + dataRdd.count())
    (dataRdd, columns)
  }

  /**
    * Read MySQL data into an RDD of rows.
    *
    * @param spark  active SparkSession
    * @param dbInfo MySQL connection configuration
    * @param sql    SQL statement to run against MySQL
    * @return RDD of rows and the array of column names
    */
  def readMysqlAsRow(spark: SparkSession, dbInfo: DBInfo, sql: String): (RDD[Row], Array[String]) = {
    val tableName = s"($sql) tmp_table"
    logger.error("sql: " + tableName)
    val connMap = Map("url" -> dbInfo.getUrl, "user" -> dbInfo.getUser, "password" -> dbInfo.getPassword, "dbtable" -> tableName)
    val df = spark.read.format("jdbc").options(connMap).load()
    val fields = df.columns
    val retRdd = df.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER_2)
    logger.error("Rows read: " + retRdd.count())
    (retRdd, fields)
  }
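
  // Note on the JDBC readers: wrapping the query as "($sql) tmp_table" turns it into a derived
  // table that Spark can pass through the "dbtable" option, so MySQL effectively executes
  // select * from (<your sql>) tmp_table. For example, sql = "select id, name from city_info"
  // becomes dbtable = "(select id, name from city_info) tmp_table".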

  /**
    * Read MySQL data into an RDD of JSON objects.
    *
    * @param spark  active SparkSession
    * @param dbInfo MySQL connection configuration
    * @param sql    SQL statement to run against MySQL
    * @return RDD of JSONObject and the array of column names
    */
  def readMysqlAsJson(spark: SparkSession, dbInfo: DBInfo, sql: String): (RDD[JSONObject], Array[String]) = {
    val tableName = s"($sql) tmp_table"
    logger.error("sql: " + tableName)
    val connMap = Map("url" -> dbInfo.getUrl, "user" -> dbInfo.getUser, "password" -> dbInfo.getPassword, "dbtable" -> tableName)
    val df = spark.read.format("jdbc").options(connMap).load()
    val fields = df.columns
    val fieldsBc = spark.sparkContext.broadcast(fields)
    val retRdd = df.rdd.map(row => {
      val json = new JSONObject()
      val tmpFields = fieldsBc.value
      // row(i) is typed Any, so box it to AnyRef to match fastjson's put(String, Object).
      for (i <- tmpFields.indices) json.put(tmpFields(i), row(i).asInstanceOf[AnyRef])
      json
    }).persist(StorageLevel.MEMORY_AND_DISK_SER_2)
    logger.error("Rows read: " + retRdd.count())
    (retRdd, fields)
  }
  /**
    * Read a CSV file into an RDD of rows.
    *
    * @param sparkSession active SparkSession
    * @param filePath     file path, local or HDFS
    * @param encoding     file encoding, defaults to utf8
    * @param seq          field separator, defaults to tab
    * @param header       whether the first line is a header row, defaults to true
    * @return RDD of rows and the header array
    */
  def readCsvAsRow(sparkSession: SparkSession, filePath: String, encoding: String = "utf8", seq: String = "\t", header: Boolean = true): (RDD[Row], Array[String]) = {
    val df = sparkSession.read.option("header", header)
      .option("encoding", encoding).option("sep", seq).csv(filePath)
    val headers = df.columns
    val dataRdd = df.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER_2)
    logger.error("Rows read from file: " + dataRdd.count())
    (dataRdd, headers)
  }

  /**
    * Read a CSV file into an RDD of JSON objects.
    *
    * @param sparkSession active SparkSession
    * @param filePath     file path, local or HDFS
    * @param encoding     file encoding, defaults to utf8
    * @param seq          field separator, defaults to tab
    * @param header       whether the first line is a header row, defaults to true
    * @return RDD of JSONObject and the header array
    */
  def readCsvAsJson(sparkSession: SparkSession, filePath: String, encoding: String = "utf8", seq: String = "\t", header: Boolean = true): (RDD[JSONObject], Array[String]) = {
    val df = sparkSession.read.option("header", header)
      .option("encoding", encoding).option("sep", seq).csv(filePath)
    val headers = df.columns
    logger.error("Header: " + headers.mkString(","))
    val headerBc = sparkSession.sparkContext.broadcast(headers)
    val dataRdd = df.rdd.map(obj => {
      val json = new JSONObject()
      val headerValue = headerBc.value
      // CSV columns are read as strings, so getString is safe here.
      for (i <- headerValue.indices) {
        json.put(headerValue(i), obj.getString(i))
      }
      json
    }).persist(StorageLevel.MEMORY_AND_DISK_SER_2)
    logger.error("Rows read from file: " + dataRdd.count())
    (dataRdd, headers)
  }


}
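
A minimal usage sketch of the class above. It assumes a Hive-enabled SparkSession, that DBInfo offers a no-arg constructor and JavaBean setters matching its getters (adapt this to how DBInfo is actually built in your project), and that the MySQL JDBC driver jar is on the classpath; the table names, connection values and HDFS path are placeholders.

import org.apache.spark.sql.SparkSession
import com.sf.gis.java.base.dto.DBInfo
import com.sf.gis.scala.base.spark.SparkRead

object SparkReadDemo {
  def main(args: Array[String]): Unit = {
    // Hive support is required for the readHive* helpers.
    val spark = SparkSession.builder()
      .appName("spark-read-demo")
      .enableHiveSupport()
      .getOrCreate()

    // Hive query to JSON; dm_gis.order_info is a placeholder table.
    val (hiveJson, hiveCols) = SparkRead.readHiveAsJson(spark, "select waybill_no, src_city from dm_gis.order_info")

    // MySQL query to rows; DBInfo setters are assumed, values are placeholders.
    val dbInfo = new DBInfo()
    dbInfo.setUrl("jdbc:mysql://127.0.0.1:3306/gis?useSSL=false")
    dbInfo.setUser("gis_user")
    dbInfo.setPassword("******")
    val (mysqlRows, mysqlCols) = SparkRead.readMysqlAsRow(spark, dbInfo, "select id, name from city_info")

    // Tab-separated CSV with a header line; the HDFS path is a placeholder.
    val (csvJson, csvHeaders) = SparkRead.readCsvAsJson(spark, "hdfs:///tmp/demo/city.csv")

    println(s"hive=${hiveJson.count()}, mysql=${mysqlRows.count()}, csv=${csvJson.count()}")
    spark.stop()
  }
}

Each helper persists its RDD with MEMORY_AND_DISK_SER_2 and triggers a count, so hold on to the returned tuple and unpersist it yourself when you are done instead of re-reading the source.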