说在前面:reduce/aggregate 会把所有分区的结果一次性拉到 driver 端合并,当分区很多时开销较大;而 treeReduce/treeAggregate 可以通过调整树的深度(depth)来控制每一轮 reduce 的规模,在 executor 端先做多级合并。
treeReduce源码:
/**
 * Reduces the elements of this RDD in a multi-level tree pattern.
 *
 * @param f     binary operator used to combine elements (should be commutative
 *              and associative, as for `reduce`)
 * @param depth suggested depth of the tree (default: 2)
 * @see [[org.apache.spark.rdd.RDD#reduce]]
 */
def treeReduce(f: (T, T) => T, depth: Int = 2): T = withScope {
  require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
  val cleanF = context.clean(f)
  // Fold each partition down to at most one value; an empty partition yields None.
  val reducePartition: Iterator[T] => Option[T] = iter =>
    if (iter.hasNext) Some(iter.reduceLeft(cleanF)) else None
  // One Option per partition — a new RDD with the same number of partitions.
  val partiallyReduced: RDD[Option[T]] = mapPartitions(it => Iterator(reducePartition(it)))
  // Merge two optional partial results; None (an empty partition) acts as identity.
  val op: (Option[T], Option[T]) => Option[T] = {
    case (Some(a), Some(b)) => Some(cleanF(a, b))
    case (some @ Some(_), None) => some
    case (None, other) => other
  }
  // Delegate the tree-shaped combination to treeAggregate. If every partition
  // was empty the result is still None, which — like reduce — is an error.
  partiallyReduced
    .treeAggregate(Option.empty[T])(op, op, depth)
    .getOrElse(throw new UnsupportedOperationException("empty collection"))
}
treeAggregate源码:
/**
 * Aggregates the elements of this RDD in a multi-level tree pattern.
 * This method is semantically identical to [[org.apache.spark.rdd.RDD#aggregate]].
 *
 * @param zeroValue initial value for both the per-partition fold (seqOp) and
 *                  the merging of partial results (combOp)
 * @param seqOp  operator used to accumulate elements within a partition
 * @param combOp operator used to merge partial results across partitions
 * @param depth suggested depth of the tree (default: 2)
 */
def treeAggregate[U: ClassTag](zeroValue: U)(
  seqOp: (U, T) => U,
  combOp: (U, U) => U,
  depth: Int = 2): U = withScope {
  require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
  if (partitions.length == 0) {
    // No partitions at all: return a private copy of the zero value so the
    // caller cannot observe/mutate the shared original.
    Utils.clone(zeroValue, context.env.closureSerializer.newInstance())
  } else {
    val cleanSeqOp = context.clean(seqOp)
    val cleanCombOp = context.clean(combOp)
    // Collapse each partition to a single U using the ordinary aggregate on
    // its iterator, producing one value per partition.
    val aggregatePartition =
      (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
    var partiallyAggregated: RDD[U] = mapPartitions(it => Iterator(aggregatePartition(it)))
    var numPartitions = partiallyAggregated.partitions.length
    // Branching factor of the tree: the depth-th root of the partition count,
    // but never less than 2.
    val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
    // If creating an extra level doesn't help reduce
    // the wall-clock time, we stop tree aggregation.
    // Don't trigger TreeAggregation when it doesn't save wall-clock time
    while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) {
      numPartitions /= scale
      val curNumPartitions = numPartitions
      // Key each partial result by (partitionIndex % curNumPartitions) so a
      // HashPartitioner groups ~scale partials per reducer, then merge each
      // group with combOp via foldByKey (executor-side, not on the driver).
      partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex {
        (i, iter) => iter.map((i % curNumPartitions, _))
      }.foldByKey(zeroValue, new HashPartitioner(curNumPartitions))(cleanCombOp).values
    }
    // Final merge of the few remaining partials on the driver; fold with a
    // cloned zero value so the caller's zeroValue instance is never mutated.
    val copiedZeroValue = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
    partiallyAggregated.fold(copiedZeroValue)(cleanCombOp)
  }
}
treeAggregate和aggregate 对比
/**
 * Compares `treeAggregate` with plain `aggregate` on a small RDD.
 *
 * Uses an explicit `main` instead of the `App` trait (the `App` trait relies
 * on DelayedInit and has initialization-order pitfalls), and stops the
 * SparkSession in a `finally` block so local resources are always released.
 */
object TreeAggregateTest {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf()
      .setAppName("TreeAggregateTest")
      .setMaster("local[6]")
    val spark = SparkSession
      .builder()
      .config(sparkConf)
      .getOrCreate()
    try {
      // 6 elements spread over 3 partitions.
      val nums: RDD[Int] = spark.sparkContext.parallelize(List(1, 2, 3, 5, 8, 9), 3)
      // seqOp: per-partition minimum (seeded with 4); combOp: sum of partials.
      val treeAggregateResult: Int = nums.treeAggregate(4)(
        (a, b) => math.min(a, b),
        (a, b) => {
          println(a + "+" + b)
          a + b
        },
        depth = 2)
      println("treeAggregateResult:" + treeAggregateResult)
      val aggregateResult: Int = nums.aggregate(4)(
        (a, b) => math.min(a, b),
        (a, b) => {
          println(a + "+" + b)
          a + b
        })
      println("aggregateResult:" + aggregateResult)
    } finally {
      // Missing in the original: without stop() the local SparkContext leaks.
      spark.stop()
    }
  }
}
aggregate 相比 treeAggregate,在合并多个分区结果(即调用 combOp)的阶段会额外使用一次初始值,而 treeAggregate 不会。seqOp 则两者都会在每个分区内调用。
我的理解是:treeAggregate 在分区数量过多时,会按照 i % curNumPartitions 对中间结果重新分区,也就是根据 depth 在 executor 端先执行若干轮 combOp,最后才把结果返回给 driver 端做最终合并。这样做能减轻 driver 端的压力,降低 OOM 的风险。
在实际操作中,treeAggregate 使用起来更加灵活。与之前的版本相比,Spark 2.3 的 treeAggregate 改用 foldByKey 来合并中间的分区结果。