Maven dependencies (pom.xml):
<!-- Spark core, SQL, and Hive modules -->
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-core_2.10</artifactId>
    <version>1.6.0</version>
</dependency>
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-sql_2.10</artifactId>
    <version>1.6.0</version>
</dependency>
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-hive_2.10</artifactId>
    <version>1.6.0</version>
</dependency>
<!-- Google Guava utility library -->
<dependency>
    <groupId>com.google.guava</groupId>
    <artifactId>guava</artifactId>
    <version>18.0</version>
</dependency>
public class StringCount extends UserDefinedAggregateFunction { /** * inputSchema指的是输入的数据类型 * @return */ @Override public StructType inputSchema() { List<StructField> fields = Lists.newArrayList(); fields.add(DataTypes.createStructField("str", DataTypes.StringType,true)); return DataTypes.createStructType(fields); } /** * bufferSchema指的是 中间进行聚合时 所处理的数据类型 * @return */ @Override public StructType bufferSchema() { List<StructField> fields = Lists.newArrayList(); fields.add(DataTypes.createStructField("count", DataTypes.IntegerType,true)); return DataTypes.createStructType(fields); } /** * dataType指的是函数返回值的类型 * @return */ @Override public DataType dataType() { return DataTypes.IntegerType; } /** * 一致性检验,如果为true,那么输入不变的情况下计算的结果也是不变的。 * @return */ @Override public boolean deterministic() { return true; } /** * 设置聚合中间buffer的初始值,但需要保证这个语义:两个初始buffer调用下面实现的merge方法后也应该为初始buffer * 即如果你初始值是1,然后你merge是执行一个相加的动作,两个初始buffer合并之后等于2, * 不会等于初始buffer了。这样的初始值就是有问题的,所以初始值也叫"zero value" * @param buffer */ @Override public void initialize(MutableAggregationBuffer buffer) { buffer.update(0,0); } /** * 用输入数据input更新buffer值,类似于combineByKey * @param buffer * @param input */ @Override public void update(MutableAggregationBuffer buffer, Row input) { buffer.update(0,Integer.valueOf(buffer.getAs(0).toString())+1); } /** * 合并两个buffer,将buffer2合并到buffer1.在合并两个分区聚合结果的时候会被用到,类似于reduceByKey * 这里要注意该方法没有返回值,在实现的时候是把buffer2合并到buffer1中去,你需要实现这个合并细节 * @param buffer1 * @param buffer2 */ @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { buffer1.update(0,Integer.valueOf(buffer1.getAs(0).toString())+Integer.valueOf(buffer2.getAs(0).toString())); } /** * 计算并返回最终的聚合结果 * @param buffer * @return */ @Override public Object evaluate(Row buffer) { return buffer.getInt(0); } }
public class UDAF { public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("UDAF").setMaster("local"); JavaSparkContext sc = new JavaSparkContext(conf); SQLContext sqlContext = new SQLContext(sc); List<String> nameList = Arrays.asList("xiaoming","xiaoming", "feifei","feifei","feifei", "katong"); //转换为javaRDD JavaRDD<String> nameRDD = sc.parallelize(nameList, 3); //转换为JavaRDD<Row> JavaRDD<Row> nameRowRDD = nameRDD.map(new Function<String, Row>() { public Row call(String name) throws Exception { return RowFactory.create(name); } }); List<StructField> fields = Lists.newArrayList(); fields.add(DataTypes.createStructField("name", DataTypes.StringType,true)); StructType structType = DataTypes.createStructType(fields); DataFrame namesDF = sqlContext.createDataFrame(nameRowRDD, structType); //注册names表 namesDF.registerTempTable("names"); sqlContext.udf().register("countString",new StringCount()); List<Row> rows = sqlContext.sql("select name,countString(name) from names group by name").javaRDD().collect(); for (Row row : rows) { System.out.println(row); } sc.close(); } }执行结果:
[feifei,3]
[xiaoming,2]
[katong,1]