This is a question that comes up in interviews. There are plenty of fixes described online, but few of them explain the implementation clearly, and there are hardly any examples. This post demonstrates the concrete implementation of each approach.
Reference articles:
http://lxw1234.com/archives/2015/06/296.htm
https://www.cnblogs.com/LHWorldBlog/p/8506121.html
Spark data-skew join tuning: https://blog.csdn.net/a1043498776/article/details/77323561
First, the mainstream approaches are the following:
1. Broadcast join (similar to a map join in Hive)
2. Broadcast the small table and implement the join yourself, which amounts to a hand-rolled broadcast join
3. Expand one table and salt the other with a random prefix, then strip it (complicated to implement, not recommended; I don't understand why this approach is so widely advocated)
The problematic code
Let's first look at the code.
package com.spark.test.offline.skewed_data
import org.apache.spark.SparkConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
/**
* Created by szh on 2020/6/5.
*/
object JOINSkewedData {
def main(args: Array[String]): Unit = {
val sparkConf = new SparkConf
sparkConf
.setAppName("Union data test")
.setMaster("local[4]")
.set("spark.sql.autoBroadcastJoinThreshold", "1048576") //1M broadcastJOIN
//.set("spark.sql.autoBroadcastJoinThreshold", "104857600") //100M broadcastJOIN
.set("spark.sql.shuffle.partitions", "3")
val spark = SparkSession.builder()
.config(sparkConf)
.getOrCreate()
val sparkContext = spark.sparkContext
sparkContext.setLogLevel("WARN")
val idArr = Array[Int](1, 2, 3)
val userArr = new ArrayBuffer[(Int, String)]()
val nameArr = Array[String]("sun", "zhen", "hua", "kk", "cc")
val threshold = 1000000
for (i <- 1 to threshold) {
var id = 10
if (i < (threshold * 0.9)) {
id = 1
} else {
id = i
}
val name = nameArr(Random.nextInt(5))
userArr.+=((id, name))
}
val rddA = sparkContext
.parallelize(userArr)
.map(x => Row(x._1, x._2))
val rddAStruct = StructType(
Array(
StructField("uid", IntegerType, nullable = true)
, StructField("name", StringType, nullable = true)
)
)
val rddADF = spark.createDataFrame(rddA, rddAStruct)
rddADF.createOrReplaceTempView("userA")
//spark.sql("CACHE TABLE userA")
//-----------------------------------------
//---------------------------------------
val arrList = new ArrayBuffer[(Int, Int)]
for (i <- 1 to threshold) {
var id = 10
if (i < 5) {
id = 1
} else {
id = i
}
val salary = Random.nextInt(100)
arrList.+=((id, salary))
}
spark
.createDataFrame(arrList).toDF("uid", "salary")
.createOrReplaceTempView("listB")
val resultDF = spark
.sql("SELECT userA.uid, name, salary FROM userA JOIN listB ON userA.uid = listB.uid")
resultDF.foreach(x => {
val i = 1
})
resultDF.show()
resultDF.explain(true)
Thread.sleep(60 * 10 * 1000)
sparkContext.stop()
}
}
We build two tables.
The user_salary table (listB in the code) has two columns: uid and salary.
The user table (userA in the code) has two columns: uid and name. The user table is heavily skewed: roughly 90% of its rows have uid = 1.
Actual execution behavior:
Figure below: breakdown of all jobs
Figure below: flow of JOB0
Figure below: flow of STAGE2 in JOB0
Figure below: task metrics of STAGE2 in JOB0,
where the data skew is clearly visible
Figure below: the SQL corresponding to the job
Figure below: execution flow of SQL0
Figure below: execution plan of SQL0
1. Broadcast join (similar to a map join in Hive)
First, the approach I recommend most is the broadcast join. It broadcasts the small table to every executor, much like a map join in Hive, so the join involves no shuffle at all.
The main step is to raise spark.sql.autoBroadcastJoinThreshold above the size of the small table so that it gets broadcast; the default is 10485760 bytes, i.e. 10 MB.
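If you would rather not raise the global threshold, the same plan can usually be forced explicitly with a broadcast hint. This is only a minimal sketch, assuming Spark 2.2+ and the userA/listB temp views registered in the full example further down; it is not part of the original code:

import org.apache.spark.sql.functions.broadcast

// Explicitly mark the small side for broadcast, independent of
// spark.sql.autoBroadcastJoinThreshold (SQL hint form).
val hinted = spark.sql(
  "SELECT /*+ BROADCAST(listB) */ userA.uid, name, salary " +
  "FROM userA JOIN listB ON userA.uid = listB.uid")

// Equivalent DataFrame-API form using the broadcast() function.
val viaApi = spark.table("userA")
  .join(broadcast(spark.table("listB")), Seq("uid"))

The full example below instead simply raises the threshold to 100 MB: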
package com.spark.test.offline.skewed_data
import org.apache.spark.SparkConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
/**
* Created by szh on 2020/6/5.
*/
object JOINSkewedData {
def main(args: Array[String]): Unit = {
val sparkConf = new SparkConf
sparkConf
.setAppName("Union data test")
.setMaster("local[4]")
//.set("spark.sql.autoBroadcastJoinThreshold", "1048576") //1M broadcastJOIN
.set("spark.sql.autoBroadcastJoinThreshold", "104857600") //100M broadcastJOIN
.set("spark.sql.shuffle.partitions", "3")
val spark = SparkSession.builder()
.config(sparkConf)
.getOrCreate()
val sparkContext = spark.sparkContext
sparkContext.setLogLevel("WARN")
val idArr = Array[Int](1, 2, 3)
val userArr = new ArrayBuffer[(Int, String)]()
val nameArr = Array[String]("sun", "zhen", "hua", "kk", "cc")
val threshold = 1000000
for (i <- 1 to threshold) {
var id = 10
if (i < (threshold * 0.9)) {
id = 1
} else {
id = i
}
val name = nameArr(Random.nextInt(5))
userArr.+=((id, name))
}
val rddA = sparkContext
.parallelize(userArr)
.map(x => Row(x._1, x._2))
val rddAStruct = StructType(
Array(
StructField("uid", IntegerType, nullable = true)
, StructField("name", StringType, nullable = true)
)
)
val rddADF = spark.createDataFrame(rddA, rddAStruct)
rddADF.createOrReplaceTempView("userA")
//spark.sql("CACHE TABLE userA")
//-----------------------------------------
//---------------------------------------
val arrList = new ArrayBuffer[(Int, Int)]
for (i <- 1 to threshold) {
var id = 10
if (i < 5) {
id = 1
} else {
id = i
}
val salary = Random.nextInt(100)
arrList.+=((id, salary))
}
spark
.createDataFrame(arrList).toDF("uid", "salary")
.createOrReplaceTempView("listB")
val resultDF = spark
.sql("SELECT userA.uid, name, salary FROM userA JOIN listB ON userA.uid = listB.uid")
resultDF.foreach(x => {
val i = 1
})
resultDF.show()
resultDF.explain(true)
Thread.sleep(60 * 10 * 1000)
sparkContext.stop()
}
}
Let's look at the execution flow.
Figure below: SQL job breakdown
Figure below: stage breakdown of Job0
Figure below: flow of Job0
Figure below: Stage0
Figure below: execution time of each task
You can see the data skew has been eliminated.
Figure below: the Spark SQL involved
Figure below: SQL execution flow
Figure below: SQL execution plan
2. Broadcast the small table, then join manually (a hand-rolled broadcast join)
Code:
package com.spark.test.offline.skewed_data
import org.apache.spark.SparkConf
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.types.{StructField, _}
import org.apache.spark.sql.{Row, SparkSession}
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
/**
* Created by szh on 2020/6/5.
*/
object JOINSkewedData2 {
def main(args: Array[String]): Unit = {
val sparkConf = new SparkConf
sparkConf
.setAppName("JOINSkewedData")
.set("spark.sql.autoBroadcastJoinThreshold", "1048576") //1M broadcastJOIN
//.set("spark.sql.autoBroadcastJoinThreshold", "104857600") //100M broadcastJOIN
.set("spark.sql.shuffle.partitions", "3")
if (args.length > 0 && args(0).equals("ide")) {
sparkConf
.setMaster("local[3]")
}
val spark = SparkSession.builder()
.config(sparkConf)
.getOrCreate()
val sparkContext = spark.sparkContext
sparkContext.setLogLevel("WARN")
sparkContext.setCheckpointDir("file:///D:/checkpoint/")
val userArr = new ArrayBuffer[(Int, String)]()
val nameArr = Array[String]("sun", "zhen", "hua", "kk", "cc")
val threshold = 1000000
for (i <- 1 to threshold) {
var id = 10
if (i < (threshold * 0.9)) {
id = 1
} else {
id = i
}
val name = nameArr(Random.nextInt(5))
userArr.+=((id, name))
}
val rddA = sparkContext
.parallelize(userArr)
//spark.sql("CACHE TABLE userA")
//-----------------------------------------
//---------------------------------------
val arrList = new ArrayBuffer[(Int, Int)]
for (i <- 1 to (threshold * 0.1).toInt) {
val id = i
val salary = Random.nextInt(100)
arrList.+=((id, salary))
}
val rddB = sparkContext
.parallelize(arrList)
val broadData: Broadcast[Array[(Int, Int)]] = sparkContext.broadcast(rddB.collect())
import scala.util.control._
val resultRdd = rddA
.mapPartitions(arr => {
val broadVal = broadData.value
var rowArr = new ArrayBuffer[Row]()
val broadMap = new mutable.HashMap[Int, Int]()
for (tmpVal <- broadVal) {
broadMap.+=((tmpVal._1, tmpVal._2))
}
while (arr.hasNext) {
val x = arr.next
if (broadMap.contains(x._1)) {
rowArr.+=(Row(x._1, x._2, broadMap(x._1)))
}
}
//TODO : test code
println(rowArr.size)
rowArr.iterator
})
// while (arr.hasNext) {
//
// val x = arr.next
// val loop = new Breaks
// var rRow: Row = null
// //var rRow: Option[Row] = None
//
// loop.breakable(
// for (tmpVal <- broadVal) {
// if (tmpVal._1 == x._1) {
// rRow = Row(tmpVal._1, x._2, tmpVal._2)
// //println(rRow)
// loop.break
// }
// }
// )
// if (rRow != null) {
// rowArr.+=(rRow)
// rRow = null
// }
// }
//
// println(rowArr.size)
//
// rowArr.iterator
// })
// .filter(x => {
// x match {
// case None => false
// case _ => true
// }
// })
val resultStruct = StructType(
Array(
StructField("uid", IntegerType, nullable = true)
, StructField("name", StringType, nullable = true)
, StructField("salary", IntegerType, nullable = true)
)
)
spark
.createDataFrame(resultRdd, resultStruct)
.createOrReplaceTempView("resultB")
val resultDF = spark
.sql("SELECT uid, name, salary FROM resultB")
resultDF.cache()
resultDF.checkpoint()
resultDF.foreach(x => {
val i = 1
})
println(resultDF.count())
resultDF.show()
resultDF.explain(true)
Thread.sleep(60 * 10 * 1000)
sparkContext.stop()
}
}
Key code
val broadData: Broadcast[Array[(Int, Int)]] = sparkContext.broadcast(rddB.collect())
val resultRdd = rddA
  .mapPartitions(arr => {
    val broadVal = broadData.value
    var rowArr = new ArrayBuffer[Row]()
    // Build a hash map from the broadcast dimension table for O(1) lookups
    val broadMap = new mutable.HashMap[Int, Int]()
    for (tmpVal <- broadVal) {
      broadMap.+=((tmpVal._1, tmpVal._2))
    }
    // Stream over this partition of the large table and emit joined rows
    while (arr.hasNext) {
      val x = arr.next
      if (broadMap.contains(x._1)) {
        rowArr.+=(Row(x._1, x._2, broadMap(x._1)))
      }
    }
    //TODO : test code
    println(rowArr.size)
    rowArr.iterator
  })
Prerequisite: the small dataset is a dimension table, and it has a primary key or some other key that serves as a unique identifier.
Step 1: broadcast the dimension table.
Step 2: run mapPartitions over the large table.
Step 3: inside mapPartitions, read the broadcast variable and build a HashMap from the dimension table.
Step 4: inside mapPartitions, iterate over each element and implement the JOIN / LEFT JOIN ... logic (a LEFT JOIN sketch is shown after this list).
Note: if the join key of the small table is not unique, it is advisable to sort by the key first and then implement the join logic; that is more involved and is not demonstrated here!
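As an illustration of Step 4, here is a minimal sketch of how the mapPartitions body above could be adapted to a LEFT JOIN, keeping unmatched rows with a null salary. Variable names follow the code above; the null is legal because salary is declared nullable in resultStruct. This is my own sketch, not part of the original code:

// LEFT JOIN variant: rows of the large table with no match in the
// broadcast dimension table are kept, with a null salary.
val leftJoinRdd = rddA
  .mapPartitions(arr => {
    val broadMap = new mutable.HashMap[Int, Int]()
    broadData.value.foreach(kv => broadMap += kv)
    arr.map { case (uid, name) =>
      // Look up the salary; box the Int so a missing key becomes null
      Row(uid, name, broadMap.get(uid).map(Int.box).orNull)
    }
  })

// The result can be registered exactly like resultRdd above:
// spark.createDataFrame(leftJoinRdd, resultStruct).createOrReplaceTempView("resultB")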
Execution efficiency:
Figure below: overall job execution
Figure below: JOB1 execution
Figure below: execution of STAGE1 in JOB1
You can see the skew has been eliminated, and no shuffle occurred!
3. Expand one table and salt the other with a random prefix, then strip it (complicated to implement, not recommended; I don't understand why this approach is so widely advocated)
The full code:
package com.spark.test.offline.skewed_data
import org.apache.spark.SparkConf
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
/**
* Created by szh on 2020/6/5.
*/
object JOINSkewedData3 {
def main(args: Array[String]): Unit = {
val sparkConf = new SparkConf
sparkConf
.setAppName("JOINSkewedData")
.set("spark.sql.autoBroadcastJoinThreshold", "-1") //1M broadcastJOIN
//.set("spark.sql.autoBroadcastJoinThreshold", "1048576") //1M broadcastJOIN
//.set("spark.sql.autoBroadcastJoinThreshold", "104857600") //100M broadcastJOIN
.set("spark.sql.shuffle.partitions", "3")
if (args.length > 0 && args(0).equals("ide")) {
sparkConf
.setMaster("local[3]")
}
val spark = SparkSession.builder()
.config(sparkConf)
.getOrCreate()
val sparkContext = spark.sparkContext
sparkContext.setLogLevel("WARN")
val userArr = new ArrayBuffer[(Int, String)]()
val nameArr = Array[String]("sun", "zhen", "hua", "kk", "cc")
val threshold = 1000000
for (i <- 1 to threshold) {
var id = 10
if (i < (threshold * 0.9)) {
id = 1
} else {
id = i
}
val name = nameArr(Random.nextInt(5))
userArr.+=((id, name))
}
val rddA = sparkContext
.parallelize(userArr)
.map(x => Row(x._1, x._2))
val rddAStruct = StructType(
Array(
StructField("uid", IntegerType, nullable = true)
, StructField("name", StringType, nullable = true)
)
)
val function1 = (key: Int) => Random.nextInt(10)
val random = udf(function1)
val rddADF = spark.createDataFrame(rddA, rddAStruct)
rddADF
.withColumn("random", random(rddADF("uid")))
.createOrReplaceTempView("userA")
//spark.sql("CACHE TABLE userA")
//-----------------------------------------
//---------------------------------------
val arrList = new ArrayBuffer[(Int, Int)]
for (i <- 1 to threshold) {
var id = 10
if (i < 5) {
id = 1
} else {
id = i
}
val salary = Random.nextInt(100)
arrList.+=((id, salary))
}
val function2 = (key: Int) => "0,1,2,3,4,5,6,7,8,9"
val randArr = udf(function2)
val oldDF = spark
.createDataFrame(arrList)
.toDF("uid", "salary")
oldDF
.withColumn("rand_arr", randArr(oldDF("uid")))
.createOrReplaceTempView("listB")
val newDF = spark
.sql("SELECT uid, salary, CAST(rand_key2 AS INT) AS rand_key FROM listB LATERAL VIEW EXPLODE(SPLIT(rand_arr, ',')) AS rand_key2")
.createOrReplaceTempView("listB_new")
//spark.sql("CACHE TABLE listB_new")
val resultDF = spark.sql("SELECT userA.uid, userA.name, salary FROM userA JOIN listB_new ON userA.uid = listB_new.uid AND userA.random = listB_new.rand_key ")
// resultDF.cache()
//
// resultDF.foreach(x => {
// val i = 1
// })
//
println(resultDF.count())
// resultDF.show()
resultDF.explain(true)
Thread.sleep(60 * 10 * 1000)
sparkContext.stop()
}
}
The technique shows up in two places.
For one table, use withColumn to add a random column:
val function1 = (key: Int) => Random.nextInt(10)
val random = udf(function1)

val rddADF = spark.createDataFrame(rddA, rddAStruct)
rddADF
  .withColumn("random", random(rddADF("uid")))
  .createOrReplaceTempView("userA")
For the other table, use UDTF machinery such as explode with LATERAL VIEW to expand each original row into multiple rows, one per possible value of the random key:

val function2 = (key: Int) => "0,1,2,3,4,5,6,7,8,9"
val randArr = udf(function2)

val oldDF = spark
  .createDataFrame(arrList)
  .toDF("uid", "salary")

oldDF
  .withColumn("rand_arr", randArr(oldDF("uid")))
  .createOrReplaceTempView("listB")

val newDF = spark
  .sql("SELECT uid, salary, CAST(rand_key2 AS INT) AS rand_key FROM listB LATERAL VIEW EXPLODE(SPLIT(rand_arr, ',')) AS rand_key2")
  .createOrReplaceTempView("listB_new")
Of course this particular API is not the only option; the same expansion can also be done with mapPartitions (a sketch follows below).
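For example, here is a minimal sketch of the same 10x expansion done on the RDD side with mapPartitions. The names SALT_BUCKETS and expandedB are mine, not from the original code, and the sketch assumes arrList holds the (uid, salary) pairs of the small table as above:

// Replicate every small-table row once per possible salt value, so the
// salted join key (uid, rand_key) spreads the skewed uid across partitions.
val SALT_BUCKETS = 10

val expandedB = sparkContext
  .parallelize(arrList)
  .mapPartitions(iter =>
    iter.flatMap { case (uid, salary) =>
      (0 until SALT_BUCKETS).map(salt => (uid, salary, salt))
    })

spark
  .createDataFrame(expandedB)
  .toDF("uid", "salary", "rand_key")
  .createOrReplaceTempView("listB_new")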
Figure below: the resulting job execution plan
Figure below: stage breakdown of JOB0
Figure below: per-task runtime of the join's shuffle stage
You can see the data skew problem has been eliminated!
Figure below: the detailed SQL execution flow; note that the two tables are still joined via SortMergeJoin.