文章目录
Spark代码可读性与性能优化——示例六(GroupBy、ReduceByKey)
1. 普通常见优化示例
1.1 错误示例 groupByKey
import org.apache.spark.{SparkConf, SparkContext}
object GroupNormal {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("GroupNormal")
val sc = new SparkContext(conf)
// 数据可能有几亿条,此处只做模拟示例
val dataRDD = sc.parallelize(List(
("hello", 2),
("java", 7),
("where", 1),
("rust", 2),
// 中间还有很多数据,不做展示
("scala", 1),
("java", 1),
("black", 9)
))
// 做一个词频统计
val result = dataRDD.groupByKey()
.mapValues(_.sum)
.sortBy(_._2, false)
result.take(10).foreach(println)
sc.stop()
}
}
1.2 正确示例 reduceByKey
// 修改此部分groupByKey代码为reduceByKey
val result = dataRDD
.reduceByKey(_ + _)
.sortBy(_._2, false)
result.take(10).foreach(println)
2. 高级优化
2.0. 需求:统计历年全国高考生中数学成绩前100名
2.1 数据示例
id | chinese | math | english | year |
---|---|---|---|---|
3412312 | 121 | 115 | 134 | 2018 |
5231211 | 103 | 131 | 114 | 2010 |
…… | …… | …… | …… | …… |
2342354 | 134 | 105 | 124 | 2014 |
- 共计约2亿条数据
- 数据存于Hive中,表名tb_student_score,id值(唯一)代表学生,chinese代表语文,math代表数学,english代表英语
2.2 存在问题的代码示例
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
/**
* 数据分组错误示例
*
* @author ALion
* @version 2019/5/15 22:33
*/
object GroupDemo {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("GroupDemo")
val spark = SparkSession.builder()
.config(conf)
.enableHiveSupport()
.getOrCreate()
// 获取原始数据
val studentDF = spark.sql(
"""
|SELECT *
|FROM tb_student_score
|WHERE id IS NOT NULL AND math IS NOT NULL AND year IS NOT NULL
""".stripMargin)
// 开始进行分析
val resultRDD = studentDF.rdd
.map(row => {
val id = row.getLong(row.fieldIndex("id"))
val math = row.getInt(row.fieldIndex("math"))
val year = row.getInt(row.fieldIndex("year"))
(year, (id, math))
})
.groupByKey() // 按年分组
.mapValues(_.toSeq.sortWith(_._2 > _._2).take(100)) // 根据math对每个人进行降序排序,最后获取前100的人
// 触发Action,展示部分统计结果
resultRDD.take(10).foreach(println)
spark.stop()
}
}
- 首先,可以肯定的是代码逻辑毫无问题,能够满足业务需求。
- 其次,这部分代码又存在很大的性能问题:
spark.sql("SELECT * FROM tb_student_score")
这种形势读取表中数据较慢,有更快的方式- groupByKey处,发生shuffle,大量数据被分到对应的年份的节点中,然后每个节点使用单线程在各年对应的所有数据中对学生进行排序,最后获取前100名
- groupByKey处的shuffle可能发生数据倾斜,可能存在部分年份的数据不全或参考人数较少,而部分年份数据较多
- 另外,直接使用SQL的方案已附在文章末尾
2.3 如何解决代码中的问题?
- 首先,读取表可以采用DataFrame的API,指定Schema,能够加速表的读取
val tbSchema = StructType(Array(
StructField("id", LongType, true),
StructField("chinese", IntegerType, true),
StructField("math", IntegerType, true),
StructField("english", IntegerType, true),
StructField("year", IntegerType, true)
))
// 获取原始数据
val studentDF = spark.read.schema(tbSchema).table("tb_student_score")
.where("id IS NOT NULL AND math IS NOT NULL AND year IS NOT NULL")
- 其次,关于groupBy发生shuffle的问题以及排序的问题。似乎数据如果不按年份分组,针对每年所有的分数统一排序,就没有其他办法。因为待排序的数据不在一起好像就不能完整的排序啊?那还怎么谈取前100名啊?
- 其实不然,想想我们是不是可以先在每个数据分块本地排序一次获取前100名,最后将所有的前100汇总,进行一次总的排序获取总的前100名?这样的话,充分利用了每个分块的并行计算,提前做了部分排序,当数据shuffle的时候每个分块数据就只有100条,最后汇总进行一次排序的数据量就非常小了!其实这就是归并排序的思想,感兴趣的朋友可以搜索‘归并排序’看看。
- 优化后的示例代码如下:
// 开始进行分析
val resultRDD = studentDF.rdd
.mapPartitions {
// 自己实现时,如果为了性能更好,不建议这样的函数式写法
// 这里只是为了方便看
_.map { row =>
val id = row.getLong(row.fieldIndex("id"))
val math = row.getInt(row.fieldIndex("math"))
val year = row.getInt(row.fieldIndex("year"))
(year, (id, math))
}.toArray
.groupBy(_._1) // 先在每个分块前,获取历年的数学前100名,减少后续groupBy的shuffle数据量
.mapValues(_.map(_._2).sortWith(_._2 > _._2).take(100))
.toIterator
}.groupByKey() // 最后获取所有分块的前100名,再次排序,计算总的前100名
.mapValues(_.flatten.toSeq.sortWith(_._2 > _._2).take(100))
// 触发Action,展示部分统计结果
resultRDD.take(10).foreach(println)
-
上述代码,已经完成功能实现。那么,这样的代码是否是最好的呢?答案是否定的。因为当前的排序是针对每个分块(Partition)的,一个Executor上有多个分块,每个分块有前100条数据需要shuffle,显然如果一个Executor一共只有100条数据需要shuffle才是最理想的!如果我们能有办法同时操纵每个Executor上的所有数据,获取前100条数据,那该多好啊!
-
我们想要的排序流程示意图如下:
-
然而,Spark并没提供一个类似mapPartition的可以对Executor上所有分块统一操作的算子(不然的话,我们就可以像mapPartion那样统计每Executor的前100名了)。不过我们有一个算子reduceByKey,它会在每个节点合并数据后再shuffle到一个节点进行最后的合并,这种行为似乎与我们需要的逻辑类似,不过好像又有那么一点不一样。
-
你可能会说reduceByKey是合并,而我们的需求是排序啊!!!是的,这看上去似乎有点矛盾。
-
事实上,这样是行得通的:
- 首先,让我们假想有这样一个集合类型A(内部是可排序的,并且只能拥有前100的数据,多余的会被删除)
- 接着,把每个元素(id,math)转换成含有一个元素的集合A
- 最后,使用reduceByKey,将每个集合依次相加合并!!!没错!就是合并!这样最后一个集合就是包含前100名的集合了。
-
这样一个集合类型A,似乎在Scala、Java中不存在,不过有一个TreeSet能保证内部有序,我们可以在数据合并后手动提取前100,这样就可以了(另外,你也可以自己实现这样一个集合:3)
-
第一步,先将id和math转为一个对象,并为这个对象实现equals、hashCode、compareTo方法,保证后续在TreeSet中的排序不会出问题。另外,再实现一个toString方法,方便我们查看打印效果!:)
- Person.class 代码 (因为Java比较易懂、易写这几个方法,这里优先采用Java的形式,后面会附上Scala对应的实现类)
public class Person implements Comparable<Person>, Serializable { private long id; private int math; public Person(long id, int math) { this.id = id; this.math = math; } @Override public int compareTo(Person person) { int result = person.math - this.math; // 降序 if (result == 0) { result = person.id - this.id > 0 ? 1 : -1; } return result; } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Person person = (Person) o; return id == person.id; } @Override public int hashCode() { return (int) (id ^ (id >>> 32)); } @Override public String toString() { return "Person{" + "id='" + id + '\'' + ", math=" + math + '}'; } }
- TreeSet 使用示例
import scala.collection.immutable.TreeSet object Demo { def main(args: Array[String]): Unit = { val set = TreeSet[Person]( new Person(1231232L, 108), new Person(3214124L, 116), new Person(1321313L, 121), new Person(6435235L, 125) ) // 获取前3名 for (elem <- set.take(3)) { println(s"--> elem = $elem") } } }
-
第二步,将原先的id、math封装为TreeSet
studentDF.rdd
.map(row => {
val id = row.getLong(row.fieldIndex("id"))
val math = row.getInt(row.fieldIndex("math"))
val year = row.getInt(row.fieldIndex("year"))
(year, TreeSet(new Person(id, math)))
})
- 最后,使用reduceByKey合并所有数据,得到前100名的结果
val resultRDD = studentDF.rdd
.map(row => {
val id = row.getLong(row.fieldIndex("id"))
val math = row.getInt(row.fieldIndex("math"))
val year = row.getInt(row.fieldIndex("year"))
(year, TreeSet(new Person(id, math)))
})
.reduceByKey((set1, set2) => set1 ++ set2 take 100) // 依次合并2个Set,并只保留前100
resultRDD.take(10).foreach(println)
- Nice!!! 这样,我们就同时解决了排序问题和数据倾斜问题!
- 进一步优化(aggregateByKey)
- 细心的朋友应该已经发现了,reduceByKey之前的map为每条的数据都生成了一个TreeSet,这样会大大增加内存消耗。
- 其实,我们只想要每个节点放一个可变的TreeSet(并且还能一直只存前100)。这样内存消耗就会更小!
- 那么我们该如何做呢?设计一个MyTreeSet,采用aggregateByKey复用同一个Set,简略的示例如下:
- MyTreeSet(简易实现,针对mutable.TreeSet封装)
import scala.collection.mutable class MyTreeSet[A](firstNum: Int, elem: Seq[A])(implicit val ord: Ordering[A]) { val set: mutable.TreeSet[A] = mutable.TreeSet[A](elem: _*) def +=(elem: A): MyTreeSet[A] = { this add elem this } def add(elem: A): Unit = { set.add(elem) // 删除排在最后的多余元素 check10Size() } def ++=(that: MyTreeSet[A]) : MyTreeSet[A] = { that.set.foreach(e => this add e) this } def check10Size(): Unit = { // 如果超过了firstNum个,就删除 if (set.size > firstNum) { set -= set.last } } override def toString: String = set.toString } object MyTreeSet { def apply[A](elem: A*)(implicit ord: Ordering[A]): MyTreeSet[A] = new MyTreeSet[A](10, elem) // 默认保留前10 def apply[A](firstNum: Int, elem: A*)(implicit ord: Ordering[A]): MyTreeSet[A] = new MyTreeSet[A](firstNum, elem) }
- Spark部分代码
val resultRDD = studentDF.rdd .map(row => { val id = row.getLong(row.fieldIndex("id")) val math = row.getInt(row.fieldIndex("math")) val year = row.getInt(row.fieldIndex("year")) (year, new Person(id, math)) }).aggregateByKey(MyTreeSet[Person](100)) ( (set, v) => set += v, (set1, set2) => set1 ++= set2 )
2.4 最终代码,以及其他附件代码
- 最终代码
import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType} import scala.collection.immutable.TreeSet object GroupDemo3 { def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("GroupDemo") val spark = SparkSession.builder() .config(conf) .enableHiveSupport() .getOrCreate() val tbSchema = StructType(Array( StructField("id", LongType, true), StructField("chinese", IntegerType, true), StructField("math", IntegerType, true), StructField("english", IntegerType, true), StructField("year", IntegerType, true) )) // 获取原始数据 val studentDF = spark.read.schema(tbSchema).table("tb_student_score") .where("id IS NOT NULL AND math IS NOT NULL AND year IS NOT NULL") // 开始进行分析 val resultRDD = studentDF.rdd .map(row => { val id = row.getLong(row.fieldIndex("id")) val math = row.getInt(row.fieldIndex("math")) val year = row.getInt(row.fieldIndex("year")) (year, new Person(id, math)) }).aggregateByKey(MyTreeSet[Person](100)) ( (set, v) => set += v, (set1, set2) => set1 ++= set2 ) // 依次合并2个Set,并只保留前100 // 触发Action,展示部分统计结果 resultRDD.take(10).foreach(println) spark.stop() } }
- Person的Scala实现
class PersonScala(val id: Long, val math: Int) extends Ordered[PersonScala] with Serializable { override def compare(that: PersonScala): Int = { var result = that.math - this.math // 降序 if (result == 0) result = if (that.id - this.id > 0) 1 else -1 result } override def equals(obj: Any): Boolean = { obj match { case person: PersonScala => this.id == person.id case _ => false } } override def hashCode(): Int = (id ^ (id >>> 32)).toInt override def toString: String = "Person{" + "id=" + id + ", math=" + math + '}' } object PersonScala { def apply(id: Long, math: Int): PersonScala = new PersonScala(id, math) }
- 示例——使用SQL获取历年数学的前100名(简单,但性能一般,且存在数据倾斜的可能)
def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("GroupDemo") val spark = SparkSession.builder() .config(conf) .enableHiveSupport() .getOrCreate() // 使用sql分析 val resultDF = spark.sql( """ |SELECT year,id,math |FROM ( | SELECT year,id,math,ROW_NUMBER() OVER (PARTITION BY year ORDER BY math DESC) rank | FROM tb_student_score |) g |WHERE g.rank <= 100 """.stripMargin) // 触发Action,展示部分统计结果 resultDF.show() spark.stop() }