sparksql 多分区读RDS的两种方式(mysql 为例)

import java.text.SimpleDateFormat
import java.util.Properties

import org.apache.spark.sql.SparkSession
import scala.collection.mutable.ArrayBuffer

object MultiplePartitionsMysql {
  var spark = SparkSession
    .builder()
    .appName(this.getClass.getSimpleName.filter(!_.equals("$")))
    //.enableHiveSupport()
    .master("local[*]")
    .config("hive.exec.dynamic.partition", "true")
    .config("hive.exec.dynamic.partition.mode", "nonstrict")
    .config("hive.exec.max.dynamic.partitions", "10000")
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .getOrCreate()
  val youqun = new Properties()
  youqun.put("user", "root")
  youqun.put("password", "123456")
  youqun.put("driver", "com.mysql.jdbc.Driver")
  youqun.put("url", "jdbc:mysql://localhost:3306/wang?characterEncoding=utf8&useSSL=false")
  val partitions = 3

  def main(args: Array[String]): Unit = {
    spark.sparkContext.setLogLevel("ERROR")
    val resultframe = readTables(youqun.getProperty("url"), "csv", youqun, 3, "load_date", "date").cache()
    //val resultframe = readTables(youqun.getProperty("url"), "csv", youqun, 4, "sequence", "long")
    //resultframe.show(10, false)
    println(resultframe.count())
    println(resultframe.rdd.getNumPartitions)
    resultframe.rdd.glom().foreach(part => println(part.getClass.getName + "=======" + part.length))
  }

  /**
   * @param url        指定url
   * @param table      指定表名
   * @param properties 连接rds设置参数
   * @param partitions 自定义分区个数
   * @param column     指定多分区划分字段
   * @param columntype 多分区划分字段类型
   * @return
   */
  def readTables(url: String, table: String, properties: Properties, partitions: Int, column: String, columntype: String) = {
    var dataFrame = spark.emptyDataFrame
    //根据 指定分区字段类型判断 多分区读取数据方式
    if (columntype.toLowerCase() == "long") {
      val array = getMaxMin(table, properties, column)
      val minNum = array(0).toLong
      val maxNum = array(1).toLong
      dataFrame = spark.read.jdbc(url, table, column, minNum, maxNum, partitions, properties)
      //如果分区字段是时间格式,根据字段划分分区区间
    } else if (columntype.toLowerCase() == "date") {
      val array = getMaxMin(table, properties, column)
      val arraypartition = generateArray(array, partitions, column)
      dataFrame = spark.read.jdbc(url, table, arraypartition, properties)
    }
    dataFrame
  }

  def getMaxMin(table: String, properties: Properties, column: String) = {
    val arrays = ArrayBuffer[String]()
    val array = spark.read.jdbc(youqun.getProperty("url"), table, youqun).selectExpr(s"min(${column}) as minNum", s"max(${column}) as maxNum").collect()
    if (array.length == 1) {
      arrays.append((array(0)(0).toString))
      arrays.append((array(0)(1).toString))
    }
    arrays.toArray
  }

  //根据最小最大值时间范围,按照指定分区个数切分成时间分区
  def generateArray(minmaxNum: Array[String], partition: Int, colum: String): Array[String] = {
    val array = ArrayBuffer[(String, String)]()
    var resultArray = Array[String]()
    //根据常见的时间格式进行调整
    if (minmaxNum(0).contains("-")) {
      val dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
      var minTime = dateFormat.parse(minmaxNum(0)).getTime()
      val maxTime = dateFormat.parse(minmaxNum(1)).getTime()
      val subNum = (maxTime - minTime) / partition.toLong
      var midNum = minTime
      for (i <- 0 to partition - 1) {
        minTime = midNum
        midNum = midNum + subNum
        if (i == partition - 1) {
          array.append(dateFormat.format(minTime) -> dateFormat.format(maxTime))
        } else {
          array.append(dateFormat.format(minTime) -> dateFormat.format(midNum))
        }
      }
    } else if (minmaxNum(0).contains("/")) {
      val dateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
      var minTime = dateFormat.parse(minmaxNum(0)).getTime()
      val maxTime = dateFormat.parse(minmaxNum(1)).getTime()
      val subNum = (maxTime - minTime) / partition.toLong
      var midNum = minTime
      for (i <- 0 to partition - 1) {
        minTime = midNum
        midNum = midNum + subNum
        if (i == partition - 1) {
          array.append(dateFormat.format(minTime) -> dateFormat.format(maxTime))
        } else {
          array.append(dateFormat.format(minTime) -> dateFormat.format(midNum))
        }
      }
    } else {
      val dateFormat = new SimpleDateFormat("yyyyMMdd HH:mm:ss")
      var minTime = dateFormat.parse(minmaxNum(0)).getTime()
      val maxTime = dateFormat.parse(minmaxNum(1)).getTime()
      val subNum = (maxTime - minTime) / partition.toLong
      var midNum = minTime
      for (i <- 0 to partition - 1) {
        minTime = midNum
        midNum = midNum + subNum
        if (i == partition - 1) {
          array.append(dateFormat.format(minTime) -> dateFormat.format(maxTime))
        } else {
          array.append(dateFormat.format(minTime) -> dateFormat.format(midNum))
        }
      }
    }
    //根据时间划分区间,并且左闭右开,避免数据时间点重叠
    resultArray = array.toArray.map {
      case (start, end) => s"'${start}'<= ${colum} and ${colum} < '${end}'"
    }
    //将最后一个时间区间范围改为闭区间
    resultArray.update(resultArray.size - 1, resultArray.last.replaceAll(s"${colum} <", s"${colum} <="))
    println(resultArray.mkString(" " + "\n"))
    resultArray
  }
}
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值