Understanding Spark's aggregate and treeAggregate

Many algorithms in spark-mllib use the treeAggregate method instead of aggregate because it gives better performance. For example, the GaussianMixture model in mllib gains roughly 20% from it; see treeAggregate.

I was not very familiar with this aggregation pattern before, so this post records what I learned.

1. An example
import org.apache.spark.sql.SparkSession

def main(args: Array[String]): Unit = {
  val spark = SparkSession
    .builder
    .appName(s"agg")
    .master("local")
    .getOrCreate()
  val sc = spark.sparkContext

  // seqOp: folds one element of a partition into the partial result
  def seqOp(s1: Int, s2: Int): Int = {
    println("seq: "+s1+":"+s2)
    s1 + s2
  }

  // combOp: merges two partial results
  def combOp(c1: Int, c2: Int): Int = {
    println("comb: "+c1+":"+c2)
    c1 + c2
  }

  // 12 elements spread across 6 partitions
  val rdd = sc.parallelize(1 to 12).repartition(6)
  val res1 = rdd.aggregate(0)(seqOp, combOp)
  // val res2 = rdd.treeAggregate(0)(seqOp, combOp)  // swap the comments to run the treeAggregate variant
  println(res1)
  // println(res2)
}
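
Before reading the traces below, it helps to see how the 12 elements are distributed across the 6 partitions. A quick way to inspect this is glom (the exact placement is shuffle-dependent, so treat the layout shown here as the one from this particular run):

  // Print the contents of each partition after repartition(6).
  // (Run inside main, where rdd is in scope.)
  rdd.glom().collect().zipWithIndex.foreach { case (elems, i) =>
    println(s"partition $i: ${elems.mkString(", ")}")
  }
  // In the run traced below the partitions hold:
  //   (6, 12), (1, 7), (2, 8), (3, 9), (4, 10), (5, 11)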

Output with aggregate:

seq: 0:6
seq: 6:12
comb: 0:18
seq: 0:1
seq: 1:7
comb: 18:8
seq: 0:2
seq: 2:8
comb: 26:10
seq: 0:3
seq: 3:9
comb: 36:12
seq: 0:4
seq: 4:10
comb: 48:14
seq: 0:5
seq: 5:11
comb: 62:16
78

Output with treeAggregate:

seq: 0:6
seq: 6:12
seq: 0:1
seq: 1:7
seq: 0:2
seq: 2:8
seq: 0:3
seq: 3:9
seq: 0:4
seq: 4:10
seq: 0:5
seq: 5:11

comb: 18:10
comb: 28:14
comb: 8:12
comb: 20:16
comb: 42:36
78

2. Aggregate

treeAggregate is a specialized form of aggregate, so understanding treeAggregate first requires understanding how aggregate combines data. The method is defined as follows:

def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U

From the definition, aggregate takes three parameters:

  1. the initial value of the aggregation: zeroValue: U
  2. the function applied to the elements inside each partition: seqOp
  3. the function that merges the per-partition results: combOp

aggregate applies seqOp inside each partition, folding over all of that partition's elements starting from zeroValue. It then applies combOp over the per-partition results, again starting from zeroValue.

Note: zeroValue is used once per partition by seqOp, and once more in the final combOp.
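
To make the two functions' roles clearer, here is a small extra example (not from the original post, just an illustration) where seqOp and combOp are genuinely different: computing a (sum, count) pair in a single pass so that the average can be derived:

  // seqOp folds one element into a partial (sum, count);
  // combOp merges two partial (sum, count) pairs.
  val (sum, count) = sc.parallelize(1 to 12).aggregate((0, 0))(
    (acc, x) => (acc._1 + x, acc._2 + 1),
    (a, b) => (a._1 + b._1, a._2 + b._2)
  )
  println(s"avg = ${sum.toDouble / count}")  // 6.5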

The computation above, illustrated as a diagram:

[Figure: aggregate]
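
As a minimal sketch (this is not the actual Spark implementation, only an equivalent computation for intuition): aggregate folds each partition with seqOp starting from zeroValue, ships one value per partition back to the driver, and folds those values with combOp, again starting from zeroValue:

  import org.apache.spark.rdd.RDD
  import scala.reflect.ClassTag

  def aggregateSketch[T, U: ClassTag](rdd: RDD[T], zeroValue: U)(
      seqOp: (U, T) => U, combOp: (U, U) => U): U = {
    val perPartition = rdd
      .mapPartitions(it => Iterator(it.foldLeft(zeroValue)(seqOp))) // seqOp fold inside each partition
      .collect()                                                    // one value per partition goes to the driver
    perPartition.foldLeft(zeroValue)(combOp)                        // combOp fold on the driver
  }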

3. treeAggregate
def treeAggregate[U: ClassTag](zeroValue: U)(
  seqOp: (U, T) => U,
  combOp: (U, U) => U,
  depth: Int = 2): U

Unlike aggregate, treeAggregate takes an extra depth parameter; the other parameters have the same meaning. After seqOp has run, aggregate pulls the per-partition results back to the driver and walks over them once with combOp to produce the final value. treeAggregate does not combine everything in one step: the seqOp results may be large, and pulling them all to the driver could cause an OutOfMemory error. Instead it first merges the partition results on the executors (using reduceByKey), coalescing partitions level by level when there are many of them, and only then reduces the remaining partial results on the driver.

Note: the difference from aggregate is that combOp is applied in two or more rounds on the executor side (the tree levels), which avoids shipping every partial value to the driver at once. Also, zeroValue does not take part in the final combOp, because the last step is a reduce rather than a fold.
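
In terms of seqOp and combOp, treeAggregate is a drop-in replacement; only depth is extra. For example, with the functions from section 1 (with just 6 partitions a larger depth changes little, but on an RDD with many partitions it adds more intermediate merge levels):

  val res3 = rdd.treeAggregate(0)(seqOp, combOp, depth = 3)
  println(res3)  // still 78; only the combining topology changes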

The details can be seen in the source code:

  /**
   * Aggregates the elements of this RDD in a multi-level tree pattern.
   *
   * @param depth suggested depth of the tree (default: 2)
   * @see [[org.apache.spark.rdd.RDD#aggregate]]
   */
  def treeAggregate[U: ClassTag](zeroValue: U)(
      seqOp: (U, T) => U,
      combOp: (U, U) => U,
      depth: Int = 2): U = withScope {
    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
    if (partitions.length == 0) {
      Utils.clone(zeroValue, context.env.closureSerializer.newInstance())
    } else {
      val cleanSeqOp = context.clean(seqOp)
      val cleanCombOp = context.clean(combOp)
      val aggregatePartition =
        (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
      var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it)))
      var numPartitions = partiallyAggregated.partitions.length
      val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
      // If creating an extra level doesn't help reduce
      // the wall-clock time, we stop tree aggregation.

      // Don't trigger TreeAggregation when it doesn't save wall-clock time
      while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) {
        numPartitions /= scale
        val curNumPartitions = numPartitions
        partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex {
          (i, iter) => iter.map((i % curNumPartitions, _))
        }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
      }
      partiallyAggregated.reduce(cleanCombOp)
    }
  }
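
To connect this source with the trace in section 1, here is a small standalone sketch (plain Scala, no Spark needed) that reproduces only the level-merging arithmetic, i.e. the scale computation and the while loop:

  // Mirrors the partition-merging arithmetic of treeAggregate.
  def traceLevels(startPartitions: Int, depth: Int): Unit = {
    var numPartitions = startPartitions
    val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
    println(s"scale = $scale")
    while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) {
      numPartitions /= scale
      println(s"merge level -> $numPartitions partitions")
    }
    println(s"final reduce over $numPartitions partial results on the driver")
  }

  traceLevels(6, 2)
  // scale = 3
  // merge level -> 2 partitions
  // final reduce over 2 partial results on the driver

This matches the trace: the six per-partition sums are first merged down to two values (42 and 36) by reduceByKey, and the final 42 + 36 = 78 happens in a reduce on the driver.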

Again, a diagram to illustrate the process:
[Figure: treeAggregate]

References:

https://www.cnblogs.com/drawwindows/p/5762392.html

http://blog.csdn.net/lookqlp/article/details/52121057

https://www.jianshu.com/p/27222830d21a
