Spark read utility class

A small utility object that reads Hive query results, MySQL tables, and CSV files into either RDD[Row] or RDD[JSONObject], returning the column names alongside the data.

package com.sf.gis.scala.base.spark

import org.apache.log4j.Logger
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.storage.StorageLevel

import com.alibaba.fastjson.JSONObject
import com.sf.gis.java.base.dto.DBInfo

/**
 * Created by 01374443 on 2020/7/27.
 */
object SparkRead {

  @transient lazy val logger: Logger = Logger.getLogger(this.getClass)

  /**
   * Work out the partition count: use the caller's value when given,
   * otherwise executor instances * cores * 4. Falls back to the default
   * parallelism when those settings are absent (e.g. under dynamic
   * allocation), since repartition(0) would throw.
   */
  private def targetPartitions(sparkSession: SparkSession, repartitions: Int): Int = {
    if (repartitions > 0) {
      repartitions
    } else {
      val conf = sparkSession.sparkContext.getConf
      val derived = conf.getInt("spark.executor.instances", 0) *
        conf.getInt("spark.executor.cores", 0) * 4
      if (derived > 0) derived else sparkSession.sparkContext.defaultParallelism
    }
  }

  /**
   * Run a Spark SQL query against Hive.
   *
   * @param sparkSession the active SparkSession
   * @param sql          the SQL statement to execute
   * @param repartitions target partition count; 0 means derive one from the executor settings
   * @return the result as an RDD of Rows, plus the column names
   */
  def readHiveAsRow(sparkSession: SparkSession, sql: String, repartitions: Int = 0): (RDD[Row], Array[String]) = {
    val df = sparkSession.sql(sql)
    val columns = df.columns
    logger.error("Columns: " + columns.mkString(","))
    val dataRdd = df.rdd
      .repartition(targetPartitions(sparkSession, repartitions))
      .persist(StorageLevel.MEMORY_AND_DISK_SER_2)
    logger.error("Total rows returned: " + dataRdd.count())
    (dataRdd, columns)
  }

  /**
   * Run a Spark SQL query against Hive and convert each row to JSON.
   *
   * @param sparkSession the active SparkSession
   * @param sql          the SQL statement to execute
   * @param repartitions target partition count; 0 means derive one from the executor settings
   * @return the result as an RDD of JSONObjects, plus the column names
   */
  def readHiveAsJson(sparkSession: SparkSession, sql: String, repartitions: Int = 0): (RDD[JSONObject], Array[String]) = {
    val df = sparkSession.sql(sql)
    val columns = df.columns
    logger.error("Columns: " + columns.mkString(","))
    val dataRdd = df.rdd
      .repartition(targetPartitions(sparkSession, repartitions))
      .map { row =>
        val jObj = new JSONObject()
        for (i <- columns.indices) {
          // row.get(i) keeps non-string Hive columns working; row.getString(i)
          // would throw a ClassCastException for them.
          jObj.put(columns(i), row.get(i))
        }
        jObj
      }
      .persist(StorageLevel.MEMORY_AND_DISK_SER_2)
    logger.error("Total rows returned: " + dataRdd.count())
    (dataRdd, columns)
  }

  /**
   * Read MySQL data as Rows.
   *
   * @param spark  the active SparkSession
   * @param dbInfo MySQL connection settings
   * @param sql    the MySQL query
   * @return the result RDD and the column names
   */
  def readMysqlAsRow(spark: SparkSession, dbInfo: DBInfo, sql: String): (RDD[Row], Array[String]) = {
    // Wrap the query as a derived table so it can be passed via the "dbtable" option.
    val tableName = s"($sql) tmp_table"
    logger.error("sql: " + tableName)
    val connMap = Map(
      "url" -> dbInfo.getUrl,
      "user" -> dbInfo.getUser,
      "password" -> dbInfo.getPassword,
      "dbtable" -> tableName)
    val df = spark.read.format("jdbc").options(connMap).load()
    val fields = df.columns
    val retRdd = df.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER_2)
    logger.error("Rows read: " + retRdd.count())
    (retRdd, fields)
  }

  /**
   * Read MySQL data as JSON.
   *
   * @param spark  the active SparkSession
   * @param dbInfo MySQL connection settings
   * @param sql    the MySQL query
   * @return the result RDD and the column names
   */
  def readMysqlAsJson(spark: SparkSession, dbInfo: DBInfo, sql: String): (RDD[JSONObject], Array[String]) = {
    val tableName = s"($sql) tmp_table"
    logger.error("sql: " + tableName)
    val connMap = Map(
      "url" -> dbInfo.getUrl,
      "user" -> dbInfo.getUser,
      "password" -> dbInfo.getPassword,
      "dbtable" -> tableName)
    val df = spark.read.format("jdbc").options(connMap).load()
    val fields = df.columns
    // Broadcast the column names so each executor gets a single copy.
    val fieldsBc = spark.sparkContext.broadcast(fields)
    val retRdd = df.rdd
      .map { row =>
        val json = new JSONObject()
        val tmpFields = fieldsBc.value
        for (i <- tmpFields.indices) {
          json.put(tmpFields(i), row(i))
        }
        json
      }
      .persist(StorageLevel.MEMORY_AND_DISK_SER_2)
    logger.error("Rows read: " + retRdd.count())
    (retRdd, fields)
  }

  /**
   * Read a CSV file into an RDD of Rows.
   *
   * @param sparkSession the active SparkSession
   * @param filePath     file path, local or HDFS
   * @param encoding     file encoding
   * @param sep          field separator
   * @param header       whether the first line is a header
   * @return an RDD of Rows, plus the header array
   */
  def readCsvAsRow(sparkSession: SparkSession, filePath: String, encoding: String = "utf8",
                   sep: String = "\t", header: Boolean = true): (RDD[Row], Array[String]) = {
    val df = sparkSession.read
      .option("header", header)
      .option("encoding", encoding)
      .option("sep", sep)
      .csv(filePath)
    val headers = df.columns
    val dataRdd = df.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER_2)
    logger.error("Rows read from file: " + dataRdd.count())
    (dataRdd, headers)
  }

  /**
   * Read a CSV file into an RDD of JSONObjects.
   *
   * @param sparkSession the active SparkSession
   * @param filePath     file path, local or HDFS
   * @param encoding     file encoding
   * @param sep          field separator
   * @param header       whether the first line is a header
   * @return an RDD of JSONObjects, plus the header array
   */
  def readCsvAsJson(sparkSession: SparkSession, filePath: String, encoding: String = "utf8",
                    sep: String = "\t", header: Boolean = true): (RDD[JSONObject], Array[String]) = {
    val df = sparkSession.read
      .option("header", header)
      .option("encoding", encoding)
      .option("sep", sep)
      .csv(filePath)
    val headers = df.columns
    logger.error("Headers: " + headers.mkString(","))
    val headerBc = sparkSession.sparkContext.broadcast(headers)
    val dataRdd = df.rdd
      .map { row =>
        val json = new JSONObject()
        val headerValue = headerBc.value
        for (i <- headerValue.indices) {
          // CSV columns are always read as strings, so getString is safe here.
          json.put(headerValue(i), row.getString(i))
        }
        json
      }
      .persist(StorageLevel.MEMORY_AND_DISK_SER_2)
    logger.error("Rows read from file: " + dataRdd.count())
    (dataRdd, headers)
  }
}
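Below is a minimal usage sketch, not part of the original post. It assumes a Hive-enabled SparkSession; the app name, database, table, and file paths are placeholders, and the MySQL call is commented out because DBInfo is an internal class whose construction is not shown here.

import org.apache.spark.sql.SparkSession
import com.sf.gis.scala.base.spark.SparkRead

object SparkReadDemo {
  def main(args: Array[String]): Unit = {
    // Hypothetical app name; enableHiveSupport is required for the Hive readers.
    val spark = SparkSession.builder()
      .appName("SparkReadDemo")
      .enableHiveSupport()
      .getOrCreate()

    // Hive query -> (RDD[JSONObject], column names); placeholder table name.
    val (hiveJson, hiveCols) = SparkRead.readHiveAsJson(spark, "SELECT id, name FROM demo_db.demo_table")
    hiveJson.take(5).foreach(j => println(j.toJSONString))

    // CSV file -> (RDD[Row], headers); tab-separated with a header line by default.
    val (csvRows, csvHeaders) = SparkRead.readCsvAsRow(spark, "hdfs:///tmp/demo.csv")
    println(csvHeaders.mkString(","))

    // MySQL -> (RDD[JSONObject], columns); requires a populated DBInfo, e.g.:
    // val (mysqlJson, mysqlCols) = SparkRead.readMysqlAsJson(spark, dbInfo, "SELECT * FROM demo")

    spark.stop()
  }
}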