Spark SQL JOIN Skew Optimization in Detail: 1. Broadcast JOIN  2. Broadcast + hand-rolled join  3. Data expansion (salting)

  This is a question you will run into in interviews. There are plenty of suggested fixes online, but few of them explain the actual implementation clearly, and working examples are rare. This post walks through a concrete implementation of each approach.

References:

0. Implementing Hive MapJoin in Spark SQL

http://lxw1234.com/archives/2015/06/296.htm

1. [Spark] How Spark Solves Data Skew

https://www.cnblogs.com/LHWorldBlog/p/8506121.html

2. Spark Data-Skew Join Tuning

https://blog.csdn.net/a1043498776/article/details/77323561

 

First, the mainstream approaches are the following:

1. Broadcast JOIN (essentially Spark's counterpart of Hive's Map JOIN)

2. Broadcast the small table yourself and implement the join by hand, i.e. a hand-rolled broadcast join

3. Data expansion with a random salt that is stripped afterwards (complicated to implement, not recommended; I am not sure why this approach is promoted so heavily)

 

The problematic code

First, let's look at the baseline code that produces the skew:

package com.spark.test.offline.skewed_data

import org.apache.spark.SparkConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}

import scala.collection.mutable.ArrayBuffer
import scala.util.Random

/**
  * Created by szh on 2020/6/5.
  */
object JOINSkewedData {

  def main(args: Array[String]): Unit = {

    val sparkConf = new SparkConf
    sparkConf
      .setAppName("Union data test")
      .setMaster("local[4]")
      .set("spark.sql.autoBroadcastJoinThreshold", "1048576") //1M broadcastJOIN
      //.set("spark.sql.autoBroadcastJoinThreshold", "104857600") //100M broadcastJOIN
      .set("spark.sql.shuffle.partitions", "3")
    val spark = SparkSession.builder()
      .config(sparkConf)
      .getOrCreate()

    val sparkContext = spark.sparkContext
    sparkContext.setLogLevel("WARN")


    val idArr = Array[Int](1, 2, 3)
    val userArr = new ArrayBuffer[(Int, String)]()
    val nameArr = Array[String]("sun", "zhen", "hua", "kk", "cc")


    val threshold = 1000000

    // roughly 90% of the rows get uid = 1, producing one heavily skewed join key
    for (i <- 1 to threshold) {

      var id = 10
      if (i < (threshold * 0.9)) {
        id = 1
      } else {
        id = i
      }
      val name = nameArr(Random.nextInt(5))

      userArr.+=((id, name))
    }

    val rddA = sparkContext
      .parallelize(userArr)
      .map(x => Row(x._1, x._2))

    val rddAStruct = StructType(
      Array(
        StructField("uid", IntegerType, nullable = true)
        , StructField("name", StringType, nullable = true)
      )
    )

    val rddADF = spark.createDataFrame(rddA, rddAStruct)
    rddADF.createOrReplaceTempView("userA")

    //spark.sql("CACHE TABLE userA")

    //-----------------------------------------
    //---------------------------------------

    val arrList = new ArrayBuffer[(Int, Int)]


    for (i <- 1 to threshold) {
      var id = 10
      if (i < 5) {
        id = 1
      } else {
        id = i
      }
      val salary = Random.nextInt(100)

      arrList.+=((id, salary))
    }

    spark
      .createDataFrame(arrList).toDF("uid", "salary")
      .createOrReplaceTempView("listB")

    val resultDF = spark
      .sql("SELECT userA.uid, name, salary FROM userA JOIN listB ON userA.uid = listB.uid")


    // no-op foreach: an action that forces the join to actually execute
    resultDF.foreach(x => {
      val i = 1
    })

    resultDF.show()
    resultDF.explain(true)

    Thread.sleep(60 * 10 * 1000) // keep the application alive for 10 minutes so the Spark UI can be inspected

    sparkContext.stop()
  }

}

We build two tables:

The listB table (user_salary) has two fields: uid and salary.

The userA table (user) has two fields: uid and name. The data in userA is heavily skewed: the vast majority of rows share uid = 1.
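Before choosing a fix, it is worth confirming which keys are actually hot. A minimal sketch (my own addition, assuming the userA temp view registered in the code above):

// Count rows per join key on the suspected skewed side; the hot keys stand out immediately.
spark.sql(
  "SELECT uid, COUNT(*) AS cnt FROM userA GROUP BY uid ORDER BY cnt DESC LIMIT 10"
).show()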

 

 

How it actually runs:

Figure: how the work is split into jobs

Figure: the flow of Job 0

Figure: Stage 2 of Job 0

Figure: per-task metrics within Stage 2 of Job 0; the data skew is clearly visible

Figure: the SQL associated with the job

Figure: the execution flow of SQL 0

Figure: the execution plan of SQL 0

 

 

1. Broadcast JOIN (Spark's counterpart of Hive's Map JOIN)

This is the approach I recommend first. The small table is broadcast to every executor, much like Hive's Map JOIN, so the join itself produces no shuffle at all; and with no shuffle on the join key, there is nothing left to skew.

The main step is to raise spark.sql.autoBroadcastJoinThreshold above the size of the small table so that it gets broadcast automatically. The default value is 10485760 bytes, i.e. 10 MB.
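If you would rather not change the session-wide threshold, Spark can also be told to broadcast a specific table explicitly, either with the broadcast() function of the DataFrame API or with a join hint in SQL (the SQL hint is available on Spark 2.2+). A small sketch, assuming the userA and listB views from the code below:

import org.apache.spark.sql.functions.broadcast

// DataFrame API: wrapping the small side in broadcast() makes the planner pick BroadcastHashJoin
val hinted = spark.table("userA").join(broadcast(spark.table("listB")), Seq("uid"))

// SQL: the BROADCAST / MAPJOIN hint has the same effect
spark.sql("SELECT /*+ BROADCAST(listB) */ userA.uid, name, salary " +
  "FROM userA JOIN listB ON userA.uid = listB.uid")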

package com.spark.test.offline.skewed_data

import org.apache.spark.SparkConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}

import scala.collection.mutable.ArrayBuffer
import scala.util.Random

/**
  * Created by szh on 2020/6/5.
  */
object JOINSkewedData {

  def main(args: Array[String]): Unit = {

    val sparkConf = new SparkConf
    sparkConf
      .setAppName("Union data test")
      .setMaster("local[4]")
      //.set("spark.sql.autoBroadcastJoinThreshold", "1048576") //1M broadcastJOIN
      .set("spark.sql.autoBroadcastJoinThreshold", "104857600") //100M broadcastJOIN
      .set("spark.sql.shuffle.partitions", "3")
    val spark = SparkSession.builder()
      .config(sparkConf)
      .getOrCreate()

    val sparkContext = spark.sparkContext
    sparkContext.setLogLevel("WARN")


    val idArr = Array[Int](1, 2, 3)
    val userArr = new ArrayBuffer[(Int, String)]()
    val nameArr = Array[String]("sun", "zhen", "hua", "kk", "cc")


    val threshold = 1000000

    for (i <- 1 to threshold) {

      var id = 10
      if (i < (threshold * 0.9)) {
        id = 1
      } else {
        id = i
      }
      val name = nameArr(Random.nextInt(5))

      userArr.+=((id, name))
    }

    val rddA = sparkContext
      .parallelize(userArr)
      .map(x => Row(x._1, x._2))

    val rddAStruct = StructType(
      Array(
        StructField("uid", IntegerType, nullable = true)
        , StructField("name", StringType, nullable = true)
      )
    )

    val rddADF = spark.createDataFrame(rddA, rddAStruct)
    rddADF.createOrReplaceTempView("userA")

    //spark.sql("CACHE TABLE userA")

    //-----------------------------------------
    //---------------------------------------

    val arrList = new ArrayBuffer[(Int, Int)]


    for (i <- 1 to threshold) {
      var id = 10
      if (i < 5) {
        id = 1
      } else {
        id = i
      }
      val salary = Random.nextInt(100)

      arrList.+=((id, salary))
    }

    spark
      .createDataFrame(arrList).toDF("uid", "salary")
      .createOrReplaceTempView("listB")

    val resultDF = spark
      .sql("SELECT userA.uid, name, salary FROM userA JOIN listB ON userA.uid = listB.uid")


    resultDF.foreach(x => {
      val i = 1
    })

    resultDF.show()
    resultDF.explain(true)

    Thread.sleep(60 * 10 * 1000)

    sparkContext.stop()
  }

}

Now let's look at how it executes with the broadcast join enabled.

Figure: how the SQL is split into jobs

Figure: stage breakdown of Job 0

Figure: the flow of Job 0

Figure: Stage 0

Figure: per-task execution times; the skew is gone

Figure: the Spark SQL involved

Figure: the SQL execution flow

Figure: the SQL execution plan
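Besides reading the UI, you can also check programmatically that the optimizer really chose a broadcast join. A small sketch (my own addition, assuming the resultDF from the code above):

// The physical plan should mention BroadcastHashJoin rather than SortMergeJoin.
val physicalPlan = resultDF.queryExecution.executedPlan.toString
println(physicalPlan)
assert(physicalPlan.contains("BroadcastHashJoin"), "expected a broadcast hash join")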

 

 

2. Broadcast the small table and implement the join by hand (a hand-rolled broadcast join)

The code:

package com.spark.test.offline.skewed_data

import org.apache.spark.SparkConf
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.types.{StructField, _}
import org.apache.spark.sql.{Row, SparkSession}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.Random

/**
  * Created by szh on 2020/6/5.
  */
object JOINSkewedData2 {

  def main(args: Array[String]): Unit = {

    val sparkConf = new SparkConf

    sparkConf
      .setAppName("JOINSkewedData")
      .set("spark.sql.autoBroadcastJoinThreshold", "1048576") //1M broadcastJOIN
      //.set("spark.sql.autoBroadcastJoinThreshold", "104857600") //100M broadcastJOIN
      .set("spark.sql.shuffle.partitions", "3")

    if (args.length > 0 && args(0).equals("ide")) {
      sparkConf
        .setMaster("local[3]")
    }

    val spark = SparkSession.builder()
      .config(sparkConf)
      .getOrCreate()


    val sparkContext = spark.sparkContext
    sparkContext.setLogLevel("WARN")
    sparkContext.setCheckpointDir("file:///D:/checkpoint/")


    val userArr = new ArrayBuffer[(Int, String)]()
    val nameArr = Array[String]("sun", "zhen", "hua", "kk", "cc")

    val threshold = 1000000

    for (i <- 1 to threshold) {

      var id = 10
      if (i < (threshold * 0.9)) {
        id = 1
      } else {
        id = i
      }
      val name = nameArr(Random.nextInt(5))

      userArr.+=((id, name))
    }

    val rddA = sparkContext
      .parallelize(userArr)

    //spark.sql("CACHE TABLE userA")

    //-----------------------------------------
    //---------------------------------------

    val arrList = new ArrayBuffer[(Int, Int)]

    for (i <- 1 to (threshold * 0.1).toInt) {
      val id = i
      val salary = Random.nextInt(100)

      arrList.+=((id, salary))
    }


    val rddB = sparkContext
      .parallelize(arrList)

    val broadData: Broadcast[Array[(Int, Int)]] = sparkContext.broadcast(rddB.collect())


    import scala.util.control._ // only needed by the commented-out Breaks variant further down

    val resultRdd = rddA
      .mapPartitions(arr => {

        val broadVal = broadData.value
        val rowArr = new ArrayBuffer[Row]()
        val broadMap = new mutable.HashMap[Int, Int]()

        for (tmpVal <- broadVal) {
          broadMap.+=((tmpVal._1, tmpVal._2))
        }

        while (arr.hasNext) {
          val x = arr.next
          if (broadMap.contains(x._1)) {
            rowArr.+=(Row(x._1, x._2, broadMap(x._1)))
          }
        }

        //TODO: debug output, joined row count for this partition
        println(rowArr.size)

        rowArr.iterator
      })

    //        (earlier variant, kept commented out for reference: a linear scan over the
    //         broadcast array using Breaks instead of the HashMap lookup above)
    //        while (arr.hasNext) {
    //
    //          val x = arr.next
    //          val loop = new Breaks
    //          var rRow: Row = null
    //          //var rRow: Option[Row] = None
    //
    //          loop.breakable(
    //            for (tmpVal <- broadVal) {
    //              if (tmpVal._1 == x._1) {
    //                rRow = Row(tmpVal._1, x._2, tmpVal._2)
    //                //println(rRow)
    //                loop.break
    //              }
    //            }
    //          )
    //          if (rRow != null) {
    //            rowArr.+=(rRow)
    //            rRow = null
    //          }
    //        }
    //
    //        println(rowArr.size)
    //
    //        rowArr.iterator
    //      })
    //      .filter(x => {
    //        x match {
    //          case None => false
    //          case _ => true
    //        }
    //      })


    val resultStruct = StructType(
      Array(
        StructField("uid", IntegerType, nullable = true)
        , StructField("name", StringType, nullable = true)
        , StructField("salary", IntegerType, nullable = true)
      )
    )

    spark
      .createDataFrame(resultRdd, resultStruct)
      .createOrReplaceTempView("resultB")

    val resultDF = spark
      .sql("SELECT uid, name, salary FROM resultB")


    resultDF.cache()
    resultDF.checkpoint() // note: checkpoint() returns a new Dataset; the result is discarded here, so only cache() affects the actions below

    resultDF.foreach(x => {
      val i = 1
    })

    println(resultDF.count())

    resultDF.show()
    resultDF.explain(true)

    Thread.sleep(60 * 10 * 1000)

    sparkContext.stop()
  }

}

 

The key part of the code:

    val broadData: Broadcast[Array[(Int, Int)]] = sparkContext.broadcast(rddB.collect())


    import scala.util.control._

    val resultRdd = rddA
      .mapPartitions(arr => {

        val broadVal = broadData.value
        val rowArr = new ArrayBuffer[Row]()
        val broadMap = new mutable.HashMap[Int, Int]()

        for (tmpVal <- broadVal) {
          broadMap.+=((tmpVal._1, tmpVal._2))
        }

        while (arr.hasNext) {
          val x = arr.next
          if (broadMap.contains(x._1)) {
            rowArr.+=(Row(x._1, x._2, broadMap(x._1)))
          }
        }

        //TODO: debug output, joined row count for this partition
        println(rowArr.size)

        rowArr.iterator
      })

 

Prerequisite: the small dataset is a dimension table, and it has a primary key (or some other key) that uniquely identifies each row.

Step 1: broadcast the dimension table.

Step 2: run mapPartitions over the large table.

Step 3: inside mapPartitions, read the broadcast variable and build a HashMap from the dimension table.

Step 4: inside mapPartitions, iterate over the rows and implement the JOIN / LEFT JOIN / ... logic yourself.

 

Note: if the join key of the small table is not unique, sort or group the broadcast rows by key first and then implement the join logic on top of that. It is a bit more involved, so the HashMap version above does not cover it; a rough sketch of a multimap variant follows below.
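For completeness, here is one possible sketch of that non-unique-key case (my own addition, not from the original post). It reuses rddA, broadData and Row from the code above and builds a multimap instead of a plain HashMap, emitting one output row per match:

val joinedRdd = rddA.mapPartitions { iter =>
  // Group the broadcast rows by key once per partition: uid -> every (uid, salary) pair with that uid
  val multiMap: Map[Int, Array[(Int, Int)]] = broadData.value.groupBy(_._1)

  iter.flatMap { case (uid, name) =>
    // Emit one joined row per matching entry: plain inner-join semantics with duplicate keys
    multiMap.getOrElse(uid, Array.empty[(Int, Int)]).map { case (_, salary) => Row(uid, name, salary) }
  }
}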

 

Execution efficiency:

Figure: overall job execution

Figure: execution of Job 1

Figure: Stage 1 of Job 1

As you can see, the skew is gone and no shuffle takes place at all!
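The absence of a shuffle can also be read off the lineage itself rather than the UI; a quick sketch (my addition, assuming the resultRdd built above):

// The lineage is purely narrow: MapPartitionsRDD on top of ParallelCollectionRDD,
// with no ShuffledRDD stage anywhere in the debug string.
println(resultRdd.toDebugString)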

 

 

3. Data expansion with a random salt, stripped afterwards (complicated to implement, not recommended). I am still not sure why this approach is promoted so heavily.

 

The full code:

package com.spark.test.offline.skewed_data

import org.apache.spark.SparkConf
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}

import scala.collection.mutable.ArrayBuffer
import scala.util.Random

/**
  * Created by szh on 2020/6/5.
  */
object JOINSkewedData3 {

  def main(args: Array[String]): Unit = {

    val sparkConf = new SparkConf

    sparkConf
      .setAppName("JOINSkewedData")
      .set("spark.sql.autoBroadcastJoinThreshold", "-1") //1M broadcastJOIN
      //.set("spark.sql.autoBroadcastJoinThreshold", "1048576") //1M broadcastJOIN
      //.set("spark.sql.autoBroadcastJoinThreshold", "104857600") //100M broadcastJOIN
      .set("spark.sql.shuffle.partitions", "3")

    if (args.length > 0 && args(0).equals("ide")) {
      sparkConf
        .setMaster("local[3]")
    }

    val spark = SparkSession.builder()
      .config(sparkConf)
      .getOrCreate()


    val sparkContext = spark.sparkContext
    sparkContext.setLogLevel("WARN")

    val userArr = new ArrayBuffer[(Int, String)]()
    val nameArr = Array[String]("sun", "zhen", "hua", "kk", "cc")

    val threshold = 1000000

    for (i <- 1 to threshold) {

      var id = 10
      if (i < (threshold * 0.9)) {
        id = 1
      } else {
        id = i
      }
      val name = nameArr(Random.nextInt(5))

      userArr.+=((id, name))
    }

    val rddA = sparkContext
      .parallelize(userArr)
      .map(x => Row(x._1, x._2))

    val rddAStruct = StructType(
      Array(
        StructField("uid", IntegerType, nullable = true)
        , StructField("name", StringType, nullable = true)
      )
    )


    val function1 = (key: Int) => Random.nextInt(10) // random salt in [0, 10)
    val random = udf(function1)

    val rddADF = spark.createDataFrame(rddA, rddAStruct)
    rddADF
      .withColumn("random", random(rddADF("uid")))
      .createOrReplaceTempView("userA")

    //spark.sql("CACHE TABLE userA")

    //-----------------------------------------
    //---------------------------------------

    val arrList = new ArrayBuffer[(Int, Int)]


    for (i <- 1 to threshold) {
      var id = 10
      if (i < 5) {
        id = 1
      } else {
        id = i
      }
      val salary = Random.nextInt(100)

      arrList.+=((id, salary))
    }


    val function2 = (key: Int) => "0,1,2,3,4,5,6,7,8,9"
    val randArr = udf(function2)

    val oldDF = spark
      .createDataFrame(arrList)
      .toDF("uid", "salary")

    oldDF
      .withColumn("rand_arr", randArr(oldDF("uid")))
      .createOrReplaceTempView("listB")

    spark
      .sql("SELECT uid, salary, CAST(rand_key2 AS INT) AS rand_key FROM listB LATERAL VIEW EXPLODE(SPLIT(rand_arr, ',')) AS rand_key2")
      .createOrReplaceTempView("listB_new")

    //spark.sql("CACHE TABLE listB_new")



    val resultDF = spark.sql("SELECT userA.uid, userA.name, salary FROM userA JOIN listB_new ON  userA.uid = listB_new.uid AND userA.random = listB_new.rand_key ")

//    resultDF.cache()
//
//    resultDF.foreach(x => {
//      val i = 1
//    })
//
    println(resultDF.count())

//    resultDF.show()
    resultDF.explain(true)

    Thread.sleep(60 * 10 * 1000)

    sparkContext.stop()
  }

}

The salting shows up in two places.

First, withColumn adds a random salt column to the skewed table (userA):

    val function1 = (key: Int) => Random.nextInt(10) // random salt in [0, 10)
    val random = udf(function1)

    val rddADF = spark.createDataFrame(rddA, rddAStruct)
    rddADF
      .withColumn("random", random(rddADF("uid")))
      .createOrReplaceTempView("userA")

Second, for the other table (listB), explode with LATERAL VIEW (a UDTF) expands each original row into one row per salt value:

val function2 = (key: Int) => "0,1,2,3,4,5,6,7,8,9"
val randArr = udf(function2)

val oldDF = spark
  .createDataFrame(arrList)
  .toDF("uid", "salary")

oldDF
  .withColumn("rand_arr", randArr(oldDF("uid")))
  .createOrReplaceTempView("listB")

spark
  .sql("SELECT uid, salary, CAST(rand_key2 AS INT) AS rand_key FROM listB LATERAL VIEW EXPLODE(SPLIT(rand_arr, ',')) AS rand_key2")
  .createOrReplaceTempView("listB_new")

Of course, this particular API is not the only way to do it; the same expansion can be done with mapPartitions (or flatMap) on the RDD, as sketched below.
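For reference, a minimal sketch of that RDD-side expansion (my own version, reusing sparkContext, spark and arrList from the code above; a mapPartitions variant would look the same, just batched per partition):

// Expand every (uid, salary) row of the small table into 10 salted copies,
// one per possible salt value 0..9, matching the exploded listB_new view above.
val saltedSmall = sparkContext
  .parallelize(arrList)
  .flatMap { case (uid, salary) =>
    (0 until 10).map(salt => (uid, salary, salt))
  }

spark
  .createDataFrame(saltedSmall)
  .toDF("uid", "salary", "rand_key")
  .createOrReplaceTempView("listB_new")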

 

 

 

 

Figure: the generated job execution plan

Figure: stage breakdown of Job 0

Figure: per-task times of the shuffle that performs the join; as you can see, the data skew is gone!

Figure: the SQL execution flow. The two tables are still joined with a SortMergeJoin; the salting simply spreads the hot key across all the salt values so that the shuffle partitions are balanced.
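Finally, the title of this approach mentions adding a prefix and stripping it afterwards. The code above keeps the salt in separate columns, so there is nothing to strip; the classic prefix-on-the-key variant would look roughly like the following sketch (my own wording in Spark SQL, assuming the plain userA and listB views):

// Salt the skewed side by prefixing its key, expand the small side with every possible
// prefix, join on the salted key, then strip the prefix again in the projection.
spark.sql(
  """
    |SELECT CAST(SPLIT(a.salted_uid, '_')[1] AS INT) AS uid, a.name, b.salary
    |FROM (
    |  SELECT CONCAT(CAST(FLOOR(RAND() * 10) AS STRING), '_', CAST(uid AS STRING)) AS salted_uid, name
    |  FROM userA
    |) a
    |JOIN (
    |  SELECT CONCAT(r, '_', CAST(uid AS STRING)) AS salted_uid, salary
    |  FROM listB
    |  LATERAL VIEW EXPLODE(SPLIT('0,1,2,3,4,5,6,7,8,9', ',')) t AS r
    |) b
    |ON a.salted_uid = b.salted_uid
  """.stripMargin).show()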

 

 

 

 
