Spark 2.X 中的累加器和 Spark 1.X中有着很大不同,下面将实现的功能是:
将一个集合,集合中含有字母 "A","B","A","D","E","D","G","H","I","A","B","I","G","D","I" ,目的是统计这个集合中字母的个数,并将其拼接成一个字符串:A=3|B=2|C=0|D=3|E=1|F=0|G=2|H=1|I=3
1、先准备拆分字符串的工具类StringUtils,主要用到了两个方法:
/**
 * Utility methods for reading and updating fields inside a delimited
 * "key=value" concatenated string such as "A=0|B=0|C=0".
 */
public class StringUtils {

    /**
     * Extracts a field value from a concatenated string.
     *
     * @param str       the concatenated string, e.g. "A=0|B=0"
     * @param delimiter the delimiter as a regular expression, e.g. "\\|"
     * @param field     the field name to look up
     * @return the field value, or null when the field is absent or an error occurs
     */
    public static String getFieldFromConcatString(String str, String delimiter, String field) {
        try {
            for (String concatField : str.split(delimiter)) {
                // split each token once instead of three times
                String[] kv = concatField.split("=");
                if (kv.length == 2 && kv[0].equals(field)) {
                    return kv[1];
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }

    /**
     * Replaces the value of one field inside a concatenated string.
     * NOTE(review): the rebuilt string always uses "|" as the output separator,
     * regardless of the delimiter parameter — confirm this is intended before
     * reusing with any other delimiter.
     *
     * @param str           the concatenated string
     * @param delimiter     the delimiter as a regular expression, e.g. "\\|"
     * @param field         the field name whose value is replaced
     * @param newFieldValue the new value for the field
     * @return the rebuilt concatenated string
     */
    public static String setFieldInConcatString(String str, String delimiter,
                                                String field, String newFieldValue) {
        String[] fields = str.split(delimiter);
        for (int i = 0; i < fields.length; i++) {
            if (fields[i].split("=")[0].equals(field)) {
                fields[i] = field + "=" + newFieldValue;
                break;
            }
        }
        // StringBuilder: no synchronization needed for a local buffer
        StringBuilder buffer = new StringBuilder();
        for (int i = 0; i < fields.length; i++) {
            if (i > 0) {
                buffer.append("|");
            }
            buffer.append(fields[i]);
        }
        return buffer.toString();
    }
}
2、实现自定义累加器(继承 AccumulatorV2 并实现其全部抽象方法)。
import org.apache.spark.util.AccumulatorV2; public class countAcc extends AccumulatorV2<String, String>{ //定义要拼接成的字符串的格式 String str="A=0|B=0|C=0|D=0|E=0|F=0|G=0|H=0|I=0"; //Returns if this accumulator is zero value or not. // e.g. for a counter accumulator, 0 is zero value; for a list accumulator, Nil is zero value. //如果这个累加器返回值为0,可以设置为false,否则会报错 java.lang.AssertionError: assertion failed: copyAndReset must return a zero value copy public boolean isZero() { return str=="A=0|B=0|C=0|D=0|E=0|F=0|G=0|H=0|I=0"; //return true; } //拷贝这个累加器 public AccumulatorV2<String, String> copy() { countAcc newAccumulator = new countAcc(); newAccumulator.str = this.str; return newAccumulator; } //Resets this accumulator, which is zero value. i.e. call isZero must return true. public void reset() { str="A=0|B=0|C=0|D=0|E=0|F=0|G=0|H=0|I=0"; } //Takes the inputs and accumulates. //参数v是每次遍历RDD传进来的值,相当于每次传进来一个字母 //在这里对字母进行拆分统计,从str字符串中找到传进来(v)的字母,找到它的value值加上1,比如: //str="A=0|B=0|C=0|D=0|E=0|F=0|G=0|H=0|I=0",如果传进来的是字母A,那么将得到的新字符串为str="A=1|B=0|C=0|D=0|E=0|F=0|G=0|H=0|I=0" //这里因为有三个分区,将会得到 // str1="A=2|B=1|C=0|D=1|E=1|F=0|G=0|H=0|I=0", str2="A=1|B=0|C=0|D=1|E=0|F=0|G=1|H=1|I=1", str3="A=0|B=1|C=0|D=1|E=0|F=0|G=1|H=0|I=2" //这三个中间结果会在下面的方法Merge中进行合并 public void add(String v) { String oldValues = StringUtils.getFieldFromConcatString(str, "\\|", v); int newValues = Integer.valueOf(oldValues) + 1; String newString = StringUtils.setFieldInConcatString(str, "\\|", v, String.valueOf(newValues)); str = newString; } //Merges another same-type accumulator into this one and update its state, i.e. this should be merge-in-place. 
//这里边other代表的是另外的相同类型的累加器,这里进行的是每个累加器的合并,相当于对上面三个str进行合并 public void merge(AccumulatorV2<String, String> other) { countAcc o =(countAcc)other; String[] words = str.split("\\|"); String[] owords = o.str.split("\\|"); for (int i = 0; i < words.length; i++) { for (int j = 0; j < owords.length; j++) { if (words[i].split("=")[0].equals(owords[j].split("=")[0])){ int value = Integer.valueOf(words[i].split("=")[1]) +Integer.valueOf(owords[j].split("=")[1]); String ns = StringUtils.setFieldInConcatString(str, "\\|", owords[j].split("=")[0], String.valueOf(value)); //每次合并完,更新str str = ns; } } } } //Defines the current value of this accumulator //当前的accumulator的值 public String value() { return str; } }
3、测试类:
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.VoidFunction;

import java.util.Arrays;
import java.util.List;

/**
 * Driver that registers the custom countAcc accumulator, feeds it every
 * letter of the sample list (split over 3 partitions), and prints the
 * merged result: A=3|B=2|C=0|D=3|E=1|F=0|G=2|H=1|I=3
 */
public class countAccumulator {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("AccumulatorTest").setMaster("local");
        JavaSparkContext sc = new JavaSparkContext(conf);
        try {
            List<String> list = Arrays.asList("A", "B", "A", "D", "E", "D", "G", "H",
                    "I", "A", "B", "I", "G", "D", "I");
            // 3 partitions: add() runs per partition, merge() combines the partials
            final JavaRDD<String> javaRDD = sc.parallelize(list, 3).cache();
            final countAcc sa = new countAcc();
            // accumulators must be registered with the SparkContext before use
            sc.sc().register(sa, "sa");
            javaRDD.foreach(new VoidFunction<String>() {
                public void call(String s) throws Exception {
                    sa.add(s);
                }
            });
            System.out.println(sa.value());
        } finally {
            // FIX: the original never closed the context (resource leak)
            sc.close();
        }
    }
}
可以得到结果:A=3|B=2|C=0|D=3|E=1|F=0|G=2|H=1|I=3