spark-mllib中许多算法用到了treeAggregate这个方法,使用该方法而不是aggregate方法能够提升算法的性能。比如mllib中的GaussianMixture模型可以提升20%的性能,见treeAggregate
此前对这种聚合方式不是很了解,因此这里记录一下。
1. 一个例子
def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder
.appName(s"agg")
.master("local")
.getOrCreate()
val sc = spark.sparkContext
def seqOp(s1:Int, s2:Int):Int = {
println("seq: "+s1+":"+s2)
s1 + s2
}
def combOp(c1: Int, c2: Int): Int = {
println("comb: "+c1+":"+c2)
c1 + c2
}
val rdd = sc.parallelize(1 to 12).repartition(6)
val res1 = rdd.aggregate(0)(seqOp, combOp)
// val res2 = rdd.treeAggregate(0)(seqOp, combOp)
println(res1)
// println(res2)
}
aggregate:
seq: 0:6
seq: 6:12
comb: 0:18
seq: 0:1
seq: 1:7
comb: 18:8
seq: 0:2
seq: 2:8
comb: 26:10
seq: 0:3
seq: 3:9
comb: 36:12
seq: 0:4
seq: 4:10
comb: 48:14
seq: 0:5
seq: 5:11
comb: 62:16
78
treeAggregate:
seq: 0:6
seq: 6:12
seq: 0:1
seq: 1:7
seq: 0:2
seq: 2:8
seq: 0:3
seq: 3:9
seq: 0:4
seq: 4:10
seq: 0:5
seq: 5:11
[Stage 2:> (0 + 0) / 2]comb: 18:10
comb: 28:14
comb: 8:12
comb: 20:16
comb: 42:36
78
2. Aggregate
treeAggregate是aggregate的一种特殊形式,因此了解treeAggregate首先需要了解aggregate的如何对数据做聚合操作。方法定义如下:
def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U
从aggregate方法的定义中,可以看到它需要传入三个参数:
- 聚合的初始值:zeroValue: U
- 对序列操作的函数:seqOp
- 聚合函数:combOp
aggregate函数将每个分区进行seqOp,且从zeroValue开始遍历分区里的所有元素。然后用combOp。从zeroValue开始遍历所有分区的结果。
注:每个partition的seqOp只应用一次zeroValue,最后的combOp也应用一次zeroValue。
用一张图来说明上面的计算过程:
3. treeAggregate
def treeAggregate[U: ClassTag](zeroValue: U)(
seqOp: (U, T) => U,
combOp: (U, U) => U,
depth: Int = 2): U
与aggregate不同的是treeAggregate多了depth的参数,其他参数含义相同。aggregate在执行完SeqOp后会将计算结果拿到driver端使用CombOp遍历一次SeqOp计算的结果,最终得到聚合结果。而treeAggregate不会一次就Comb得到最终结果,SeqOp得到的结果也许很大,直接拉到driver可能会OutOfMemory,因此它会先把分区的结果做局部聚合(reduceByKey),如果分区数过多时会做分区合并,之后再把结果拿到driver端做reduce。
注:与aggregate不同的地方是:在每个分区,会做两次或者多次combOp,避免将所有局部的值传给driver端。另外,初始值zeroValue不会参与combOp。
具体可以参见源码:
/**
* Aggregates the elements of this RDD in a multi-level tree pattern.
*
* @param depth suggested depth of the tree (default: 2)
* @see [[org.apache.spark.rdd.RDD#aggregate]]
*/
def treeAggregate[U: ClassTag](zeroValue: U)(
seqOp: (U, T) => U,
combOp: (U, U) => U,
depth: Int = 2): U = withScope {
require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
if (partitions.length == 0) {
Utils.clone(zeroValue, context.env.closureSerializer.newInstance())
} else {
val cleanSeqOp = context.clean(seqOp)
val cleanCombOp = context.clean(combOp)
val aggregatePartition =
(it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it)))
var numPartitions = partiallyAggregated.partitions.length
val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
// If creating an extra level doesn't help reduce
// the wall-clock time, we stop tree aggregation.
// Don't trigger TreeAggregation when it doesn't save wall-clock time
while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) {
numPartitions /= scale
val curNumPartitions = numPartitions
partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex {
(i, iter) => iter.map((i % curNumPartitions, _))
}.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
}
partiallyAggregated.reduce(cleanCombOp)
}
}
还是用一张图来说明:
参考:
https://www.cnblogs.com/drawwindows/p/5762392.html