Spark代码可读性与性能优化——示例七(构建聚合器,以用于复杂聚合)

31 篇文章 1 订阅
27 篇文章 3 订阅

Spark代码可读性与性能优化——示例七(构建聚合器,以用于复杂聚合)

1. 多列聚合

1.1 前情提要

  • 第六篇中,有个需求“统计历年全国高考生中数学成绩前100名”,咱们已经完成了。可是,突然领导又来需求了:
    • 领导:“你再统计一下语文、英语、物理……的前100名。下班前给我!”
    • 你:“~!@#¥%……”
    • 然后,你还是去统计(想了想,也就那么几个科目,大不了全部都跑一次!),终于在下班的时候统计完了,上交成果!!!Nice!^_^
    • 第二天,领导又找了张表,还是要求你统计每列字段的前100名,然而这张表有50多个字段。
    • 你:“……”
  • 然后你就去跑50多次????估计加班都跑不完!!所以,咱们得想个办法,能不能一次就统计完?这个时候,就要改写代码,学会进行多列统计(顺便一提,sql boy是写不出来只进行一次遍历就统计完所有列前100的SQL的!^。^他们还是得跑50多次遍历表……终于体现出代码狗的优势了!!!=O=)

1.2 尝试进行本地多列聚合

  • Person类,就用Scala版的吧
class Person(val id: Long, val grade: Int) extends Ordered[Person] with Serializable {

  override def compare(that: Person): Int = {
    var result = that.grade - this.grade // 降序
    if (result == 0)
      result = if (that.id - this.id > 0) 1 else -1
    result
  }

  override def equals(obj: Any): Boolean = {
    obj match {
      case person: Person => this.id == person.id
      case _ => false
    }
  }

  override def hashCode(): Int = (id ^ (id >>> 32)).toInt

  override def toString: String = "Person{" + "id=" + id + ", grade=" + grade + "}"

}

object Person {

  def apply(id: Long, grade: Int): Person = new Person(id, grade)

}
  • 多列聚合的时候,每个成绩分开聚合就OK了。不过咱们需要找到一个类,能够装下3种排名,不过这次似乎真的没有这种类了。
  • 现在,咱们需要自己编写一个聚合器类,用作聚合。而这个类的中心属性应该分别是数学前10集合、语文前10集合、英语前10集合,每次合并2个类时,分别将3个集合一一合并!示例如下:
/**
  * Description: 数学、语文、英语的前NUM名的聚合器
  * <br/>
  * Date: 2019/11/27 1:39
  *
  * @author ALion
  */
class PersonAggregator(val mathSet: MyTreeSet[Person],
                       val chineseSet : MyTreeSet[Person],
                       val englishSet : MyTreeSet[Person]) {

  /**
    * 向聚合器添加单个元素
    * @param element (人的id, 数学, 语文, 英语)
    * @return this PersonAggregator
    */
  def +=(element: (Long, Int, Int, Int)): PersonAggregator = {
    this.mathSet += Person(element._1, element._2)
    this.chineseSet += Person(element._1, element._3)
    this.englishSet += Person(element._1, element._4)

    this
  }

  /**
    * 聚合成绩的方法
    * @param that 另一个聚合器
    * @return this PersonAggregator
    */
  def ++=(that: PersonAggregator): PersonAggregator = {
    this.mathSet ++= that.mathSet
    this.chineseSet ++= that.chineseSet
    this.englishSet ++= that.englishSet

    this
  }

  override def toString: String =
    "PersonAggregator{" +
      "mathSet=" + mathSet +
      ", chineseSet=" + chineseSet +
      ", englishSet=" + englishSet +
      '}'

}

object PersonAggregator {

  def apply(): PersonAggregator =
    new PersonAggregator(MyTreeSet[Person](), MyTreeSet[Person](), MyTreeSet[Person]())

}
  • 最后,在本地写个测试代码,试试看
object Demo {

  def main(args: Array[String]): Unit = {
	// 此处,我让MyTreeSet取的前2名,修改后面附录的MyTreeSet即可
    val aggregator1 = PersonAggregator()
    aggregator1 += (1, 80, 92, 100) += (2, 85, 90, 78) += (3, 88, 95, 67)
    println(s"aggregator1 = $aggregator1")

  }

}
  • 我的输出结果如下,没问题!^_^
aggregator = PersonAggregator{mathSet=TreeSet(Person{id=3, grade=88}, Person{id=2, grade=85}), chineseSet=TreeSet(Person{id=3, grade=95}, Person{id=1, grade=92}), englishSet=TreeSet(Person{id=1, grade=100}, Person{id=2, grade=78})}
  • MyTreeSet在最后附录处

1.3 多列聚合最终代码

  • 那么,修改我们Spark统计部分的主体代码,开始运行吧!--------------------->
    val resultRDD = studentDF.rdd
        .map(row => {
          val id = row.getLong(row.fieldIndex("id"))
          val math = row.getInt(row.fieldIndex("math"))
          val chinese = row.getInt(row.fieldIndex("chinese"))
          val english = row.getInt(row.fieldIndex("english"))
          val year = row.getInt(row.fieldIndex("year"))

          (year, (id, math, chinese, english))
        })
        .aggregateByKey(PersonAggregator())(
          (agg, v) => agg += v,
          (agg1, agg2) => agg1 ++= agg2
        ) // 依次合并2个聚合器PersonAggregator

2. 单列多重聚合(简单示例)

2.1 前情提要

  • 前面咱们已经写出了多列聚合的代码,愉快的下了班……然而万恶的需求又来了
    • 这次,领导说:“给我统计一下每年数学前100名的,顺便算下每年的平均数学成绩……哦,还有数学考了0分的有多少个!”
    • 你:“~!@#¥%……”(我的内心是崩溃的! TAT)
    • 然而……毕竟只是只程序狗,还是得做。
  • 不过,咱们做就要做得Perfect,还是一次统计完!不像sql boy一样偷偷摸摸地搞多次,浪费集群资源!真是可耻!

2.2 尝试进行本地单列多重聚合

  • Person类还是前面那个,就不贴代码了
  • 单列多重聚合其实和多列聚合相似,关键是抓住业务需求(前100名,平均成绩,考0分的人数),以此为聚合点,构建你的聚合器(聚合点+聚合算法),看代码
/**
  * Description: 数学前100名,数学平均成绩,数学为0分的人数 -> 聚合器
  *
  * @note {{{
  *      前100名 -> mathSet
  *      分数之和 -> totalGrade
  *      总人数 -> totalCount
  *      平均成绩 -> totalGrade / totalCount (如果Long不够大,你可以换其他专用的数据类型,例如BigInt)
  *      0分的人数 -> zeroCount
  * }}}
  *
  * Date: 2019/11/27 1:39
  * @author ALion
  */
class PersonAggregator2(val mathSet: MyTreeSet[Person],
                        var totalGrade: Long, var totalCount: Long,
                        var zeroCount: Long) {
  /**
    * 向聚合器添加单个元素
    * @param element (人的id, 数学)
    * @return this PersonAggregator
    */
  def +=(element: (Long, Int)): PersonAggregator2 = {
    this.mathSet += Person(element._1, element._2)
    this.totalGrade += element._2
    this.totalCount += 1
    if (element._2 == 0) this.zeroCount += 1

    this
  }

  /**
    * 聚合成绩、人数的方法
    *
    * @param that 另一个聚合器
    * @return this PersonAggregator
    */
  def ++=(that: PersonAggregator2): PersonAggregator2 = {
    this.mathSet ++= that.mathSet
    this.totalGrade += that.totalGrade
    this.totalCount += that.totalCount
    this.zeroCount += that.zeroCount

    this
  }

  /**
    * 计算平均值
    */
  def calcAVG(): Double = {
    totalGrade / totalCount.toDouble
  }

  override def toString: String =
    "PersonAggregator2{" +
      "mathSet=" + mathSet +
      ", avgGrade=" + calcAVG() +
      ", zeroCount=" + zeroCount +
      '}'

}

object PersonAggregator2 {

  def apply(): PersonAggregator2 =
    new PersonAggregator2(MyTreeSet[Person](), 0, 0 ,0)

}
  • 最后,在本地写个测试代码
import scala.collection.immutable.TreeSet

object Demo {

  def main(args: Array[String]): Unit = {

    val aggregator2 = PersonAggregator2()
    aggregator2 += (1, 80) += (2, 0) += (3, 0)
    println(s"aggregator2 = $aggregator2")
    
  }

}
  • 结果如下。Perfect! ^_^Just a piece of cake!
aggregator2 = PersonAggregator2{mathSet=TreeSet(Person{id=1, grade=80}, Person{id=3, grade=0}, Person{id=2, grade=0}), avgGrade=26.666666666666668, zeroCount=2}

2.3 单列多重聚合最终代码

  • Spark统计部分的主体代码如下
    val resultRDD = studentDF.rdd
      .map(row => {
        val id = row.getLong(row.fieldIndex("id"))
        val math = row.getInt(row.fieldIndex("math"))
        val year = row.getInt(row.fieldIndex("year"))

        (year, (id, math))
      })
      .aggregateByKey(PersonAggregator2())(
        (agg, v) => agg += v,
        (agg1, agg2) => agg1 ++= agg2
      )

3. 单列多重聚合(复杂示例)

3.1 新的需求

  • 添加对数学的均方根误差(RMSE)的统计

3.2 聚合算法分析

  • RMSE的计算公式: 1 m ∑ i = 1 m ( x i − x − ) 2 \sqrt{\frac{1}{m}\sum_{i=1}^{m} (x_{i} - _x^{-})^2} m1i=1m(xix)2

  • 咋一看上去似乎不可能能够一次性统计完,因为似乎得先算出平均数,才能继续计算RMSE的值啊!你的思路或许是这样的:

    • 第一步,求均值、总数
    • 第二步,对所有值与均值的差的方求和,然后将和除以总数,再开方
  • 上面的逻辑没有问题,但是真的就不能一次完成聚合吗?

  • 让我们先尝试对聚合算法进行拆解(当然有的算法确实没法拆解),对RMSE的算法进行转换,过程如下:

    • 1 m ∑ i = 1 m ( x i − x − ) 2 \sqrt{\frac{1}{m}\sum_{i=1}^{m} (x_{i} - _x^{-})^2} m1i=1m(xix)2
    • ⇒ t r a n s f o r m \xRightarrow{transform} transform
    • 1 m ∑ i = 1 m ( x i 2 − 2 x i x − + x − 2 ) \sqrt{\frac{1}{m}\sum_{i=1}^{m} (x_{i}^2 - 2{x_i}_x^{-} + {_x^{-}}^2)} m1i=1m(xi22xix+x2)
    • ⇒ t r a n s f o r m \xRightarrow{transform} transform
    • 1 m ∑ i = 1 m x i 2 − 2 x − m ∑ i = 1 m x i + x − 2 \sqrt{\frac{1}{m}\sum_{i=1}^{m} x_{i}^2 - \frac{2_x^{-}}{m}\sum_{i=1}^{m} {x_i} + {_x^{-}}^2} m1i=1mxi2m2xi=1mxi+x2
    • ⇒ t r a n s f o r m \xRightarrow{transform} transform
    • 1 m ∑ i = 1 m x i 2 − 2 x − 2 + x − 2 \sqrt{\frac{1}{m}\sum_{i=1}^{m} x_{i}^2 - 2{_x^{-}}^2 + {_x^{-}}^2} m1i=1mxi22x2+x2
    • ⇒ t r a n s f o r m \xRightarrow{transform} transform
    • 1 m ∑ i = 1 m x i 2 − x − 2 \sqrt{\frac{1}{m}\sum_{i=1}^{m} x_{i}^2 - {_x^{-}}^2} m1i=1mxi2x2
    • 最后,我们拆成了根号内的两部分,分别由以下参数组成:
      • m 代表 数据的总数量
      • ∑ i = 1 m x i 2 \sum_{i=1}^{m} x_{i}^2 i=1mxi2 代表 对所有值的方求和
      • x − _x^{-} x 代表 所有值的均值(等于 ∑ i = 1 m x i \sum_{i=1}^{m} x_i i=1mxi除以m)
  • 现在来看,显然简单了,你只需要找到m、 ∑ i = 1 m x i 2 \sum_{i=1}^{m} x_{i}^2 i=1mxi2 ∑ i = 1 m x i \sum_{i=1}^{m} x_i i=1mxi即可

3.3 代码编写示例

  • Person类不变
  • 编写聚合器Aggregator,记住关键点在于:
    • 分别聚合出m、 ∑ i = 1 m x i 2 \sum_{i=1}^{m} x_{i}^2 i=1mxi2 ∑ i = 1 m x i \sum_{i=1}^{m} x_i i=1mxi的值
    • 套用刚才最后得出的公式,计算出RMSE
/**
  * Description: 数学平均成绩,RMSE -> 聚合器
  *
  * @note {{{
  *      分数之和 -> totalGrade
  *      总人数 -> totalCount
  *      平均成绩 -> totalGrade / totalCount
  *      所有分数平方的和 -> sqrtSum
  *      (如果Long不够大,你可以换其他专用的数据类型,例如BigInt)
  * }}}
  *
  * Date: 2019/11/27 1:39
  * @author ALion
  */
class PersonAggregator3(var totalGrade: Long, var totalCount: Long, var powSum: Long) {


  /**
    * 聚合成绩、人数的方法
    *
    * @param that 另一个聚合器
    * @return this PersonAggregator
    */
  def ++(that: PersonAggregator3): PersonAggregator3 = {
    this.totalGrade += that.totalGrade
    this.totalCount += that.totalCount
    this.powSum += that.powSum

    new PersonAggregator3(totalGrade, totalCount, powSum)
  }

  /**
    * 计算平均值
    */
  def calcAVG(): Double = {
    totalGrade / totalCount.toDouble
  }

  /**
    * 根据化简后的公式计算 RMSE
    */
  def calcRMSE(): Double = {
    val avg = calcAVG()
    Math.sqrt(powSum / totalCount.toDouble - avg * avg)
  }

  // 懂lazy的话,就按下面的写法写
  //  lazy val avg: Double = totalGrade / totalCount.toDouble
  //
  //  lazy val rmse: Double = Math.sqrt(sqrtSum / totalCount.toDouble - avg * avg)

  override def toString: String =
    "PersonAggregator3{" +
      "avgGrade=" + calcAVG() +
      ", rmse=" + calcRMSE() +
      '}'

}

object PersonAggregator3 {

  def apply(math: Int): PersonAggregator3 =
    new PersonAggregator3(math, 1, math * math)

}
  • Spark统计部分的主体代码
val resultRDD = studentDF.rdd
  .map(row => {
    val math = row.getInt(row.fieldIndex("math"))
    val year = row.getInt(row.fieldIndex("year"))

	// 此处不用为每个元素生成一个大对象(集合等),无需使用aggregateByKey,你可以试着写一下:)
    (year, PersonAggregator3(math))
  }).reduceByKey(_ ++ _)

3.4 利用SparkSQL自定义聚合函数求解RMSE,见附录

4. 多列多重聚合

  • 噢???这个已经不用说怎么做了吧?
  • 领导:“对了,刚才的求数学前100名、均值、0分人数、RMSE的统计,再给我把语文、英语、物理……也整一个!”
    给我也整一个
    有多少整多少

5. 总结与整理

  • 编写聚合器Aggregator的目的:(不写聚合器类,也能完成需求,请看附录)
    • 更好的封装性,易于扩展
    • 防止与聚合无关的代码发生耦合,以做到"高内聚,低耦合"
    • 有利于编写更加复杂的聚合逻辑
    • 代码更优雅、简洁
  • 编写聚合器Aggregator的关键点:
    • 以需要聚合的点(最值、均值、计数等)为该聚合器的属性或者全局变量
    • 构建聚合算法(例如示例中的++,你也可以写成其他名字),编写你需要的聚合逻辑(除了一般的求均值、最值等聚合外,你还可以编写各种复杂的聚合需求!)
  • 编写聚合器之前,你可以先找找Java是否有可以直接使用的类,以用于聚合:
  • 如果一个业务首先想到的是需要写groupByKey算子解决,那么你可以尝试使用本篇提供的思路来解决问题。另外,想法不要太机械,这只是一个通用的方法示例,不同的业务有不同的需求,可以按照不同的方式实现,很多业务需求不用这样编写聚合器也能完成聚合任务!!!(你可以试试前面的“统计字段值的总数、去重后的总数”,有很多种方法实现快速聚合
  • 不是一定要根据key分组聚合的话,你还可以尝试直接使用treeReduce方法聚合
  • 不适用的场景:
    • 这种聚合方式只适用于多的数据聚合成少量的聚合(不然也不叫聚合了^_^哈哈!如果最后生成的数据量没有变化,无论怎么优化,也毫无意义,因为这些数据始终存在,必然占用空间)。例如,你要求排序,并保留所有数据的顺序,而不是取前n名。
    • 单次聚合分析必须要所有数据才能得出结果的业务是不适用的。(或者你可以考虑如何分化该聚合业务需求,很多业务并不是真的一定要得到所有数据后才能开始聚合)
    • 本身需要多次串行的聚合逻辑的业务不直接适用。(多想想再做决定,例如RMSE就是看起来必须先计算平均值再求结果,实际上转化计算公式后就不一样了)
  • 关于现在网上有很多人传“用随机数的方式进行两次shuffle解决数据倾斜的复杂聚合问题”的方案,这种方案确实是有不错的效果,不过就是变成了2次shuffle,把问题弄复杂了

6. 附录

  • 不用这里的聚合器,也是可以实现我们的需求的,示例如下:
// 针对前面求RMSE的业务
val resultRDD = studentDF.rdd
  .map { row =>
    val math = row.getInt(row.fieldIndex("math"))
    val year = row.getInt(row.fieldIndex("year"))

    (year, (math, 1, math * math))
  }.reduceByKey { case (t1, t2) =>
    (t1._1 + t2._1, t1._2 + t2._2, t1._3 + t2._3)
  }.mapValues { case (totalGrade, totalCount, powSum) =>
    val avg = totalGrade / totalCount.toDouble
    val rmse = Math.sqrt(powSum / totalCount.toDouble - avg * avg)
    (avg, rmse)
  }
  • 利用SparkSQL自定义聚合函数求解RMSE
/**
 * Description: 自定义求RMSE的聚合函数
 * 
 * @example {{{
 *   spark.udf.register("rmseUDAF", new MyRmseUDAF())
 *   spark.sql("SELECT rmseUDAF(math) FROM tb_person")
 * }}}
 *
 * @author ALion
 */
class MyRmseUDAF extends UserDefinedAggregateFunction{

  override def inputSchema: StructType = StructType(
    StructField("math", LongType) :: Nil
  )

  override def bufferSchema: StructType = StructType(
    StructField("totalGrade", LongType) ::
    StructField("totalCount", LongType) ::
    StructField("powSum", LongType) :: Nil
  )

  override def dataType: DataType = DoubleType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0L)
    buffer.update(1, 0L)
    buffer.update(2, 0L)
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // totalGrade
    buffer.update(0, buffer.getLong(0) + input.getLong(0))
    // totalCount
    buffer.update(1, buffer.getLong(1) + 1)
    // powSum
    buffer.update(2, buffer.getLong(2) + input.getLong(0) * input.getLong(0))
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // totalGrade
    buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
    // totalCount
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
    // powSum
    buffer1.update(2, buffer1.getLong(2) + buffer2.getLong(2))
  }

  override def evaluate(buffer: Row): Any = {
    val totalGrade = buffer.getLong(0)
    val totalCount = buffer.getLong(1).toDouble
    val powSum = buffer.getLong(2)
    val avg = totalGrade / totalCount
    // RMSE
    Math.sqrt(powSum / totalCount - avg * avg)
  }

}
  • MyTreeSet(简易实现,针对mutable.TreeSet封装)
import scala.collection.mutable

class MyTreeSet[A](firstNum: Int, elem: Seq[A])(implicit val ord: Ordering[A]) {

  val set: mutable.TreeSet[A] = mutable.TreeSet[A](elem: _*)

  def +=(elem: A): MyTreeSet[A] = {
    this add elem

    this
  }

  def add(elem: A): Unit = {
    set.add(elem)

    // 删除排在最后的多余元素
    check10Size()
  }

  def ++=(that: MyTreeSet[A]) : MyTreeSet[A] = {
    that.set.foreach(e => this add e)

    this
  }

  def check10Size(): Unit = {
    // 如果超过了firstNum个,就删除
    if (set.size > firstNum) {
      set -= set.last
    }
  }

  override def toString: String = set.toString
}

object MyTreeSet {

  def apply[A](elem: A*)(implicit ord: Ordering[A]): MyTreeSet[A] = new MyTreeSet[A](10, elem)
  
  def apply[A](firstNum: Int, elem: A*)(implicit ord: Ordering[A]): MyTreeSet[A] = new MyTreeSet[A](firstNum, elem)
  
}
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值