自定义累加器
/**
 * Custom Spark accumulator that keeps a running person count and age sum.
 *
 * <p>A custom accumulator extends {@code AccumulatorV2<IN, OUT>} and fixes both the
 * input type (what is passed to {@link #add}) and the output type (what
 * {@link #value} returns); here both are {@code MyKey}.
 */
public class MyAccumulator extends AccumulatorV2<MyKey, MyKey> implements Serializable {

    /** Running totals held by this accumulator instance; starts at (0, 0). */
    private MyKey info = new MyKey(0, 0);

    /**
     * Returns the state object currently held by this accumulator.
     *
     * @return the internal (mutable) totals object
     */
    public MyKey getInfo() {
        return info;
    }

    /**
     * Replaces the state object held by this accumulator.
     *
     * @param info the new totals object
     */
    public void setInfo(MyKey info) {
        this.info = info;
    }

    /**
     * Reports whether this accumulator is still in its initial state, i.e. both the
     * person count and the age sum are zero.
     *
     * @return {@code true} when both counters equal zero
     */
    @Override
    public boolean isZero() {
        return info.getPersonAgeSum() == 0 && info.getPersonNum() == 0;
    }

    /**
     * Creates an independent copy of this accumulator carrying the same totals.
     *
     * <p>The counters are copied into a fresh {@code MyKey} so that later calls to
     * {@link #add} or {@link #merge} on either instance do not affect the other.
     *
     * @return a new accumulator whose state equals, but is not shared with, this one
     */
    @Override
    public AccumulatorV2<MyKey, MyKey> copy() {
        MyAccumulator duplicate = new MyAccumulator();
        // Copy the values rather than the reference: MyKey is mutable.
        duplicate.info = new MyKey(this.info.getPersonNum(), this.info.getPersonAgeSum());
        return duplicate;
    }

    /**
     * Resets this accumulator to its zero state by installing a fresh (0, 0) totals object.
     */
    @Override
    public void reset() {
        info = new MyKey(0, 0);
    }

    /**
     * Folds one incoming record into the running totals.
     *
     * @param v the record to accumulate; its count and age sum are added to the totals
     */
    @Override
    public void add(MyKey v) {
        info.setPersonNum(info.getPersonNum() + v.getPersonNum());
        info.setPersonAgeSum(info.getPersonAgeSum() + v.getPersonAgeSum());
    }

    /**
     * Merges the totals of another accumulator of the same type into this one.
     *
     * @param other the accumulator whose totals are added to this accumulator's totals
     */
    @Override
    public void merge(AccumulatorV2<MyKey, MyKey> other) {
        MyKey value = other.value();
        info.setPersonNum(info.getPersonNum() + value.getPersonNum());
        info.setPersonAgeSum(info.getPersonAgeSum() + value.getPersonAgeSum());
    }

    /**
     * Returns the accumulated totals.
     *
     * @return the internal totals object (not a copy)
     */
    @Override
    public MyKey value() {
        return info;
    }
}
自定义key
/**
 * Serializable value object pairing a person count with the sum of those persons' ages.
 */
public class MyKey implements Serializable {

    /** Number of persons counted. */
    private Integer personNum;

    /** Sum of the ages of the counted persons. */
    private Integer personAgeSum;

    /**
     * Creates a key with both counters at zero.
     *
     * <p>The fields are boxed {@code Integer}s; leaving them {@code null} would make any
     * arithmetic on a default-constructed instance fail with a
     * {@code NullPointerException} during unboxing.
     */
    public MyKey() {
        this(0, 0);
    }

    /**
     * Creates a key with the given counters.
     *
     * @param personNum    number of persons
     * @param personAgeSum sum of the persons' ages
     */
    public MyKey(Integer personNum, Integer personAgeSum) {
        this.personNum = personNum;
        this.personAgeSum = personAgeSum;
    }

    /** @return the number of persons */
    public Integer getPersonNum() {
        return personNum;
    }

    /** @param personNum the new number of persons */
    public void setPersonNum(Integer personNum) {
        this.personNum = personNum;
    }

    /** @return the sum of the persons' ages */
    public Integer getPersonAgeSum() {
        return personAgeSum;
    }

    /** @param personAgeSum the new sum of the persons' ages */
    public void setPersonAgeSum(Integer personAgeSum) {
        this.personAgeSum = personAgeSum;
    }

    /**
     * Two keys are equal when they are of the same class and both counters match.
     * The comparison is null-safe because the setters accept {@code null}.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        MyKey that = (MyKey) o;
        return java.util.Objects.equals(personNum, that.personNum)
                && java.util.Objects.equals(personAgeSum, that.personAgeSum);
    }

    /** Hash code consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return java.util.Objects.hash(personNum, personAgeSum);
    }

    @Override
    public String toString() {
        return "MyKey{" +
                "personNum=" + personNum +
                ", personAgeSum=" + personAgeSum +
                '}';
    }
}
运行:
/**
 * Demo driver: registers a {@code MyAccumulator}, feeds it one {@code MyKey} per input
 * record, and prints the accumulated totals.
 */
public class MyRun {

    /**
     * Runs the accumulator demo on a local Spark master.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        SparkConf conf = new SparkConf();
        conf.setAppName("testAccumulator");
        conf.setMaster("local");

        // try-with-resources stops the Spark context even when the job throws.
        try (JavaSparkContext sc = new JavaSparkContext(conf)) {
            MyAccumulator acc = new MyAccumulator();
            sc.sc().register(acc, "PersonInfoAccumulator");

            // Each record is "<name> <age>".
            JavaRDD<String> rdd = sc.parallelize(Arrays.asList(
                    "zhangsan 1", "lisi 2", "wangwu 3", "zhaoliu 4", "tianqi 5", "zhengba 6"
            ));

            // Update the accumulator inside an action (foreach) rather than a
            // transformation (map): Spark only guarantees exactly-once application of
            // accumulator updates for actions; updates made in transformations may be
            // applied again if a task or stage is re-executed.
            rdd.foreach(line -> acc.add(new MyKey(1, Integer.parseInt(line.split(" ")[1]))));

            System.out.println("value = " + acc.value());
        }
    }
}
结果:
value = MyKey{personNum=6, personAgeSum=21}