# Spark RDD Operator Series: the treeAggregate Operator

1. Function signature:

```scala
def treeAggregate[U: ClassTag](zeroValue: U)(
    seqOp: (U, T) => U,
    combOp: (U, U) => U,
    depth: Int = 2): U = withScope
```

2. What it does:

Like aggregate, treeAggregate takes three parameters (an initial value zeroValue, a function seqOp, and a function combOp), plus an optional depth; the return type U is the same as the type of zeroValue.

Processing steps:

1. seqOp folds each element of a partition into the accumulated value (starting from zeroValue) and returns the per-partition result.

2. combOp merges the per-partition results, level by level, and returns the final value.

3. depth is used to compute the tree's scale factor: `val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)`, where numPartitions is the number of partitions of the RDD.
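The scale formula above can be checked with plain Java (no Spark required); the partition counts and depths below are illustrative values, not anything mandated by Spark:

```java
public class ScaleDemo {
    // Mirrors Spark's scale computation:
    // scale = max(ceil(numPartitions^(1/depth)), 2)
    static int scale(int numPartitions, int depth) {
        return Math.max((int) Math.ceil(Math.pow(numPartitions, 1.0 / depth)), 2);
    }

    public static void main(String[] args) {
        System.out.println(scale(1000, 2)); // ceil(1000^(1/2)) = 32
        System.out.println(scale(100, 3));  // ceil(100^(1/3)) = 5
        System.out.println(scale(3, 5));    // ceil(3^(1/5)) = 2, already at the clamped minimum
    }
}
```

A larger depth yields a smaller scale, i.e. more but narrower merge levels in the tree.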

3. Demo

```java
package com.lilei.rdd;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;

import java.util.Arrays;
import java.util.List;

public class TreeAggregateTest {

    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf().setAppName("treeaggregateTest")
                .setMaster("local[*]");
        JavaSparkContext javaSparkContext = new JavaSparkContext(sparkConf);
        List<Integer> data = Arrays.asList(5, 1, 1, 4, 4, 2, 2);
        JavaRDD<Integer> javaRDD = javaSparkContext.parallelize(data, 3);
        // Transformation: convert each Integer element to its String form
        JavaRDD<String> javaRDD1 = javaRDD.map(new Function<Integer, String>() {
            @Override
            public String call(Integer v1) throws Exception {
                return Integer.toString(v1);
            }
        });

        // seqOp concatenates elements within each partition;
        // combOp merges the per-partition results
        String result1 = javaRDD1.treeAggregate("0", new Function2<String, String, String>() {
            @Override
            public String call(String v1, String v2) throws Exception {
                System.out.println(v1 + "=seq=" + v2);
                return v1 + "=seq=" + v2;
            }
        }, new Function2<String, String, String>() {
            @Override
            public String call(String v1, String v2) throws Exception {
                System.out.println(v1 + "<=comb=>" + v2);
                return v1 + "<=comb=>" + v2;
            }
        });
        System.out.println(result1);
    }
}
```


4. Output

The exact console output varies from run to run: the `=seq=` lines are printed as seqOp runs inside each of the three partitions, in whatever order the tasks finish, and the `<=comb=>` lines appear as the partition results are merged by combOp.

5. Why treeAggregate

When an RDD has many partitions, a plain aggregate sends every partition's result back to the driver and merges them all there, which can make the driver a bottleneck. treeAggregate instead inserts extra reduce rounds on the executors (the number of levels is controlled by depth), so each round shrinks the number of partial results and the driver only has to merge the last, much smaller batch.

Source:

```scala
// Don't trigger TreeAggregation when it doesn't save wall-clock time
while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) {
  numPartitions /= scale
  val curNumPartitions = numPartitions
  partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex {
    (i, iter) => iter.map((i % curNumPartitions, _))
  }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
}
partiallyAggregated.reduce(cleanCombOp)
```
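To see how this loop shrinks the partition count, the Spark-free sketch below reproduces only the loop's arithmetic in plain Java (the actual shuffling via reduceByKey is omitted, and the partition counts are illustrative):

```java
public class TreeRoundsDemo {
    // Simulates the partition-count arithmetic of treeAggregate's while loop
    // and returns how many partitions the final driver-side reduce sees.
    static int finalPartitions(int numPartitions, int depth) {
        int scale = Math.max((int) Math.ceil(Math.pow(numPartitions, 1.0 / depth)), 2);
        // Same stop condition as the Spark source: only keep adding tree
        // rounds while they save work overall.
        while (numPartitions > scale + Math.ceil((double) numPartitions / scale)) {
            numPartitions /= scale; // one extra reduceByKey round on the executors
        }
        return numPartitions;
    }

    public static void main(String[] args) {
        // With 1000 partitions and depth 2, scale = ceil(sqrt(1000)) = 32:
        // one tree round shrinks 1000 -> 31 partitions, so the driver merges
        // 31 partial results instead of 1000.
        System.out.println(finalPartitions(1000, 2)); // 31
        // With only 3 partitions the loop never fires: tree aggregation
        // would not save wall-clock time, so the driver merges all 3.
        System.out.println(finalPartitions(3, 2));    // 3
    }
}
```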