文章目录
Spark代码可读性与性能优化——示例七(构建聚合器,以用于复杂聚合)
- 接第六篇,如未看过的同学,请先看Spark代码可读性与性能优化——示例六(groupBy、reduceByKey、aggregateByKey)
1. 多列聚合
1.1 前情提要
- 第六篇中,有个需求“统计历年全国高考生中数学成绩前100名”,咱们已经完成了。可是,突然领导又来需求了:
- 领导:“你再统计一下语文、英语、物理……的前100名。下班前给我!”
- 你:“~!@#¥%……”
- 然后,你还是去统计(想了想,也就那么几个科目,大不了全部都跑一次!),终于在下班的时候统计完了,上交成果!!!Nice!^_^
- 第二天,领导又找了张表,还是要求你统计每列字段的前100名,然而这张表有50多个字段。
- 你:“……”
- 然后你就去跑50多次????估计加班都跑不完!!所以,咱们得想个办法,能不能一次就统计完?这个时候,就要改写代码,学会进行多列统计(顺便一提,sql boy是写不出来只进行一次遍历就统计完所有列前100的SQL的!^。^他们还是得跑50多次遍历表……终于体现出代码狗的优势了!!!=O=)
1.2 尝试进行本地多列聚合
- Person类,就用Scala版的吧
class Person(val id: Long, val grade: Int) extends Ordered[Person] with Serializable {
override def compare(that: Person): Int = {
var result = that.grade - this.grade // 降序
if (result == 0)
result = if (that.id - this.id > 0) 1 else -1
result
}
override def equals(obj: Any): Boolean = {
obj match {
case person: Person => this.id == person.id
case _ => false
}
}
override def hashCode(): Int = (id ^ (id >>> 32)).toInt
override def toString: String = "Person{" + "id=" + id + ", grade=" + grade + "}"
}
object Person {
def apply(id: Long, grade: Int): Person = new Person(id, grade)
}
- 多列聚合的时候,每个成绩分开聚合就OK了。不过咱们需要找到一个类,能够装下3种排名,不过这次似乎真的没有这种类了。
- 现在,咱们需要自己编写一个聚合器类,用作聚合。而这个类的中心属性应该分别是数学前10集合、语文前10集合、英语前10集合,每次合并2个类时,分别将3个集合一一合并!示例如下:
/**
* Description: 数学、语文、英语的前NUM名的聚合器
* <br/>
* Date: 2019/11/27 1:39
*
* @author ALion
*/
class PersonAggregator(val mathSet: MyTreeSet[Person],
val chineseSet : MyTreeSet[Person],
val englishSet : MyTreeSet[Person]) {
/**
* 向聚合器添加单个元素
* @param element (人的id, 数学, 语文, 英语)
* @return this PersonAggregator
*/
def +=(element: (Long, Int, Int, Int)): PersonAggregator = {
this.mathSet += Person(element._1, element._2)
this.chineseSet += Person(element._1, element._3)
this.englishSet += Person(element._1, element._4)
this
}
/**
* 聚合成绩的方法
* @param that 另一个聚合器
* @return this PersonAggregator
*/
def ++=(that: PersonAggregator): PersonAggregator = {
this.mathSet ++= that.mathSet
this.chineseSet ++= that.chineseSet
this.englishSet ++= that.englishSet
this
}
override def toString: String =
"PersonAggregator{" +
"mathSet=" + mathSet +
", chineseSet=" + chineseSet +
", englishSet=" + englishSet +
'}'
}
object PersonAggregator {
def apply(): PersonAggregator =
new PersonAggregator(MyTreeSet[Person](), MyTreeSet[Person](), MyTreeSet[Person]())
}
- 最后,在本地写个测试代码,试试看
object Demo {
def main(args: Array[String]): Unit = {
// 此处,我让MyTreeSet取的前2名,修改后面附录的MyTreeSet即可
val aggregator1 = PersonAggregator()
aggregator1 += (1, 80, 92, 100) += (2, 85, 90, 78) += (3, 88, 95, 67)
println(s"aggregator1 = $aggregator1")
}
}
- 我的输出结果如下,没问题!^_^
aggregator = PersonAggregator{mathSet=TreeSet(Person{id=3, grade=88}, Person{id=2, grade=85}), chineseSet=TreeSet(Person{id=3, grade=95}, Person{id=1, grade=92}), englishSet=TreeSet(Person{id=1, grade=100}, Person{id=2, grade=78})}
- MyTreeSet在最后附录处
1.3 多列聚合最终代码
- 那么,修改我们Spark统计部分的主体代码,开始运行吧!--------------------->
val resultRDD = studentDF.rdd
.map(row => {
val id = row.getLong(row.fieldIndex("id"))
val math = row.getInt(row.fieldIndex("math"))
val chinese = row.getInt(row.fieldIndex("chinese"))
val english = row.getInt(row.fieldIndex("english"))
val year = row.getInt(row.fieldIndex("year"))
(year, (id, math, chinese, english))
})
.aggregateByKey(PersonAggregator())(
(agg, v) => agg += v,
(agg1, agg2) => agg1 ++= agg2
) // 依次合并2个聚合器PersonAggregator
2. 单列多重聚合(简单示例)
2.1 前情提要
- 前面咱们已经写出了多列聚合的代码,愉快的下了班……然而万恶的需求又来了
- 这次,领导说:“给我统计一下每年数学前100名的,顺便算下每年的平均数学成绩……哦,还有数学考了0分的有多少个!”
- 你:“~!@#¥%……”(我的内心是崩溃的! TAT)
- 然而……毕竟只是只程序狗,还是得做。
- 不过,咱们做就要做得Perfect,还是一次统计完!不像sql boy一样偷偷摸摸地搞多次,浪费集群资源!真是可耻!
2.2 尝试进行本地单列多重聚合
- Person类还是前面那个,就不贴代码了
- 单列多重聚合其实和多列聚合相似,关键是抓住业务需求(前100名,平均成绩,考0分的人数),以此为聚合点,构建你的聚合器(聚合点+聚合算法),看代码
/**
* Description: 数学前100名,数学平均成绩,数学为0分的人数 -> 聚合器
*
* @note {{{
* 前100名 -> mathSet
* 分数之和 -> totalGrade
* 总人数 -> totalCount
* 平均成绩 -> totalGrade / totalCount (如果Long不够大,你可以换其他专用的数据类型,例如BigInt)
* 0分的人数 -> zeroCount
* }}}
*
* Date: 2019/11/27 1:39
* @author ALion
*/
class PersonAggregator2(val mathSet: MyTreeSet[Person],
var totalGrade: Long, var totalCount: Long,
var zeroCount: Long) {
/**
* 向聚合器添加单个元素
* @param element (人的id, 数学)
* @return this PersonAggregator
*/
def +=(element: (Long, Int)): PersonAggregator2 = {
this.mathSet += Person(element._1, element._2)
this.totalGrade += element._2
this.totalCount += 1
if (element._2 == 0) this.zeroCount += 1
this
}
/**
* 聚合成绩、人数的方法
*
* @param that 另一个聚合器
* @return this PersonAggregator
*/
def ++=(that: PersonAggregator2): PersonAggregator2 = {
this.mathSet ++= that.mathSet
this.totalGrade += that.totalGrade
this.totalCount += that.totalCount
this.zeroCount += that.zeroCount
this
}
/**
* 计算平均值
*/
def calcAVG(): Double = {
totalGrade / totalCount.toDouble
}
override def toString: String =
"PersonAggregator2{" +
"mathSet=" + mathSet +
", avgGrade=" + calcAVG() +
", zeroCount=" + zeroCount +
'}'
}
object PersonAggregator2 {
def apply(): PersonAggregator2 =
new PersonAggregator2(MyTreeSet[Person](), 0, 0 ,0)
}
- 最后,在本地写个测试代码
import scala.collection.immutable.TreeSet
object Demo {
def main(args: Array[String]): Unit = {
val aggregator2 = PersonAggregator2()
aggregator2 += (1, 80) += (2, 0) += (3, 0)
println(s"aggregator2 = $aggregator2")
}
}
- 结果如下。Perfect! ^_^Just a piece of cake!
aggregator2 = PersonAggregator2{mathSet=TreeSet(Person{id=1, grade=80}, Person{id=3, grade=0}, Person{id=2, grade=0}), avgGrade=26.666666666666668, zeroCount=2}
2.3 单列多重聚合最终代码
- Spark统计部分的主体代码如下
val resultRDD = studentDF.rdd
.map(row => {
val id = row.getLong(row.fieldIndex("id"))
val math = row.getInt(row.fieldIndex("math"))
val year = row.getInt(row.fieldIndex("year"))
(year, (id, math))
})
.aggregateByKey(PersonAggregator2())(
(agg, v) => agg += v,
(agg1, agg2) => agg1 ++= agg2
)
3. 单列多重聚合(复杂示例)
3.1 新的需求
- 添加对数学的均方根误差(RMSE)的统计
3.2 聚合算法分析
-
RMSE的计算公式: 1 m ∑ i = 1 m ( x i − x − ) 2 \sqrt{\frac{1}{m}\sum_{i=1}^{m} (x_{i} - _x^{-})^2} m1i=1∑m(xi−x−)2
-
咋一看上去似乎不可能能够一次性统计完,因为似乎得先算出平均数,才能继续计算RMSE的值啊!你的思路或许是这样的:
- 第一步,求均值、总数
- 第二步,对所有值与均值的差的方求和,然后将和除以总数,再开方
-
上面的逻辑没有问题,但是真的就不能一次完成聚合吗?
-
让我们先尝试对聚合算法进行拆解(当然有的算法确实没法拆解),对RMSE的算法进行转换,过程如下:
- 1 m ∑ i = 1 m ( x i − x − ) 2 \sqrt{\frac{1}{m}\sum_{i=1}^{m} (x_{i} - _x^{-})^2} m1i=1∑m(xi−x−)2
- ⇒ t r a n s f o r m \xRightarrow{transform} transform
- 1 m ∑ i = 1 m ( x i 2 − 2 x i x − + x − 2 ) \sqrt{\frac{1}{m}\sum_{i=1}^{m} (x_{i}^2 - 2{x_i}_x^{-} + {_x^{-}}^2)} m1i=1∑m(xi2−2xix−+x−2)
- ⇒ t r a n s f o r m \xRightarrow{transform} transform
- 1 m ∑ i = 1 m x i 2 − 2 x − m ∑ i = 1 m x i + x − 2 \sqrt{\frac{1}{m}\sum_{i=1}^{m} x_{i}^2 - \frac{2_x^{-}}{m}\sum_{i=1}^{m} {x_i} + {_x^{-}}^2} m1i=1∑mxi2−m2x−i=1∑mxi+x−2
- ⇒ t r a n s f o r m \xRightarrow{transform} transform
- 1 m ∑ i = 1 m x i 2 − 2 x − 2 + x − 2 \sqrt{\frac{1}{m}\sum_{i=1}^{m} x_{i}^2 - 2{_x^{-}}^2 + {_x^{-}}^2} m1i=1∑mxi2−2x−2+x−2
- ⇒ t r a n s f o r m \xRightarrow{transform} transform
- 1 m ∑ i = 1 m x i 2 − x − 2 \sqrt{\frac{1}{m}\sum_{i=1}^{m} x_{i}^2 - {_x^{-}}^2} m1i=1∑mxi2−x−2
- 最后,我们拆成了根号内的两部分,分别由以下参数组成:
- m 代表 数据的总数量
- ∑ i = 1 m x i 2 \sum_{i=1}^{m} x_{i}^2 ∑i=1mxi2 代表 对所有值的方求和
- x − _x^{-} x− 代表 所有值的均值(等于 ∑ i = 1 m x i \sum_{i=1}^{m} x_i ∑i=1mxi除以m)
-
现在来看,显然简单了,你只需要找到m、 ∑ i = 1 m x i 2 \sum_{i=1}^{m} x_{i}^2 ∑i=1mxi2、 ∑ i = 1 m x i \sum_{i=1}^{m} x_i ∑i=1mxi即可
3.3 代码编写示例
- Person类不变
- 编写聚合器Aggregator,记住关键点在于:
- 分别聚合出m、 ∑ i = 1 m x i 2 \sum_{i=1}^{m} x_{i}^2 ∑i=1mxi2、 ∑ i = 1 m x i \sum_{i=1}^{m} x_i ∑i=1mxi的值
- 套用刚才最后得出的公式,计算出RMSE
/**
* Description: 数学平均成绩,RMSE -> 聚合器
*
* @note {{{
* 分数之和 -> totalGrade
* 总人数 -> totalCount
* 平均成绩 -> totalGrade / totalCount
* 所有分数平方的和 -> sqrtSum
* (如果Long不够大,你可以换其他专用的数据类型,例如BigInt)
* }}}
*
* Date: 2019/11/27 1:39
* @author ALion
*/
class PersonAggregator3(var totalGrade: Long, var totalCount: Long, var powSum: Long) {
/**
* 聚合成绩、人数的方法
*
* @param that 另一个聚合器
* @return this PersonAggregator
*/
def ++(that: PersonAggregator3): PersonAggregator3 = {
this.totalGrade += that.totalGrade
this.totalCount += that.totalCount
this.powSum += that.powSum
new PersonAggregator3(totalGrade, totalCount, powSum)
}
/**
* 计算平均值
*/
def calcAVG(): Double = {
totalGrade / totalCount.toDouble
}
/**
* 根据化简后的公式计算 RMSE
*/
def calcRMSE(): Double = {
val avg = calcAVG()
Math.sqrt(powSum / totalCount.toDouble - avg * avg)
}
// 懂lazy的话,就按下面的写法写
// lazy val avg: Double = totalGrade / totalCount.toDouble
//
// lazy val rmse: Double = Math.sqrt(sqrtSum / totalCount.toDouble - avg * avg)
override def toString: String =
"PersonAggregator3{" +
"avgGrade=" + calcAVG() +
", rmse=" + calcRMSE() +
'}'
}
object PersonAggregator3 {
def apply(math: Int): PersonAggregator3 =
new PersonAggregator3(math, 1, math * math)
}
- Spark统计部分的主体代码
val resultRDD = studentDF.rdd
.map(row => {
val math = row.getInt(row.fieldIndex("math"))
val year = row.getInt(row.fieldIndex("year"))
// 此处不用为每个元素生成一个大对象(集合等),无需使用aggregateByKey,你可以试着写一下:)
(year, PersonAggregator3(math))
}).reduceByKey(_ ++ _)
3.4 利用SparkSQL自定义聚合函数求解RMSE,见附录
4. 多列多重聚合
- 噢???这个已经不用说怎么做了吧?
- 领导:“对了,刚才的求数学前100名、均值、0分人数、RMSE的统计,再给我把语文、英语、物理……也整一个!”
5. 总结与整理
- 编写聚合器Aggregator的目的:(不写聚合器类,也能完成需求,请看附录)
- 更好的封装性,易于扩展
- 防止与聚合无关的代码发生耦合,以做到"高内聚,低耦合"
- 有利于编写更加复杂的聚合逻辑
- 代码更优雅、简洁
- 编写聚合器Aggregator的关键点:
- 以需要聚合的点(最值、均值、计数等)为该聚合器的属性或者全局变量
- 构建聚合算法(例如示例中的++,你也可以写成其他名字),编写你需要的聚合逻辑(除了一般的求均值、最值等聚合外,你还可以编写各种复杂的聚合需求!)
- 编写聚合器之前,你可以先找找Java是否有可以直接使用的类,以用于聚合:
- 例如,第六篇的TreeSet,在reduceByKey处合并Set后,再取前n名
- 例如,考虑“同时统计某个表所有字段对应的值的总数、去重后的总数,并要求对应字段值非空”时,你可以使用HashMap(key存字段值,value存该字段值的数量,你只需要在reduceByKey处编写一个合并Map的方法)
- 如果一个业务首先想到的是需要写groupByKey算子解决,那么你可以尝试使用本篇提供的思路来解决问题。另外,想法不要太机械,这只是一个通用的方法示例,不同的业务有不同的需求,可以按照不同的方式实现,很多业务需求不用这样编写聚合器也能完成聚合任务!!!(你可以试试前面的“统计字段值的总数、去重后的总数”,有很多种方法实现快速聚合)
- 不是一定要根据key分组聚合的话,你还可以尝试直接使用treeReduce方法聚合
- 不适用的场景:
- 这种聚合方式只适用于多的数据聚合成少量的聚合(不然也不叫聚合了^_^哈哈!如果最后生成的数据量没有变化,无论怎么优化,也毫无意义,因为这些数据始终存在,必然占用空间)。例如,你要求排序,并保留所有数据的顺序,而不是取前n名。
- 单次聚合分析必须要所有数据才能得出结果的业务是不适用的。(或者你可以考虑如何分化该聚合业务需求,很多业务并不是真的一定要得到所有数据后才能开始聚合)
- 本身需要多次串行的聚合逻辑的业务不直接适用。(多想想再做决定,例如RMSE就是看起来必须先计算平均值再求结果,实际上转化计算公式后就不一样了)
- 关于现在网上有很多人传“用随机数的方式进行两次shuffle解决数据倾斜的复杂聚合问题”的方案,这种方案确实是有不错的效果,不过就是变成了2次shuffle,把问题弄复杂了
6. 附录
- 不用这里的聚合器,也是可以实现我们的需求的,示例如下:
// 针对前面求RMSE的业务
val resultRDD = studentDF.rdd
.map { row =>
val math = row.getInt(row.fieldIndex("math"))
val year = row.getInt(row.fieldIndex("year"))
(year, (math, 1, math * math))
}.reduceByKey { case (t1, t2) =>
(t1._1 + t2._1, t1._2 + t2._2, t1._3 + t2._3)
}.mapValues { case (totalGrade, totalCount, powSum) =>
val avg = totalGrade / totalCount.toDouble
val rmse = Math.sqrt(powSum / totalCount.toDouble - avg * avg)
(avg, rmse)
}
- 利用SparkSQL自定义聚合函数求解RMSE
/**
* Description: 自定义求RMSE的聚合函数
*
* @example {{{
* spark.udf.register("rmseUDAF", new MyRmseUDAF())
* spark.sql("SELECT rmseUDAF(math) FROM tb_person")
* }}}
*
* @author ALion
*/
class MyRmseUDAF extends UserDefinedAggregateFunction{
override def inputSchema: StructType = StructType(
StructField("math", LongType) :: Nil
)
override def bufferSchema: StructType = StructType(
StructField("totalGrade", LongType) ::
StructField("totalCount", LongType) ::
StructField("powSum", LongType) :: Nil
)
override def dataType: DataType = DoubleType
override def deterministic: Boolean = true
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0, 0L)
buffer.update(1, 0L)
buffer.update(2, 0L)
}
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
// totalGrade
buffer.update(0, buffer.getLong(0) + input.getLong(0))
// totalCount
buffer.update(1, buffer.getLong(1) + 1)
// powSum
buffer.update(2, buffer.getLong(2) + input.getLong(0) * input.getLong(0))
}
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
// totalGrade
buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
// totalCount
buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
// powSum
buffer1.update(2, buffer1.getLong(2) + buffer2.getLong(2))
}
override def evaluate(buffer: Row): Any = {
val totalGrade = buffer.getLong(0)
val totalCount = buffer.getLong(1).toDouble
val powSum = buffer.getLong(2)
val avg = totalGrade / totalCount
// RMSE
Math.sqrt(powSum / totalCount - avg * avg)
}
}
- MyTreeSet(简易实现,针对mutable.TreeSet封装)
import scala.collection.mutable
class MyTreeSet[A](firstNum: Int, elem: Seq[A])(implicit val ord: Ordering[A]) {
val set: mutable.TreeSet[A] = mutable.TreeSet[A](elem: _*)
def +=(elem: A): MyTreeSet[A] = {
this add elem
this
}
def add(elem: A): Unit = {
set.add(elem)
// 删除排在最后的多余元素
check10Size()
}
def ++=(that: MyTreeSet[A]) : MyTreeSet[A] = {
that.set.foreach(e => this add e)
this
}
def check10Size(): Unit = {
// 如果超过了firstNum个,就删除
if (set.size > firstNum) {
set -= set.last
}
}
override def toString: String = set.toString
}
object MyTreeSet {
def apply[A](elem: A*)(implicit ord: Ordering[A]): MyTreeSet[A] = new MyTreeSet[A](10, elem)
def apply[A](firstNum: Int, elem: A*)(implicit ord: Ordering[A]): MyTreeSet[A] = new MyTreeSet[A](firstNum, elem)
}