Spark UDAF
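The StringCount class below is a user-defined aggregate function (UDAF) for Spark SQL's Java API (Spark 1.x): it counts the rows in each group, which makes its result easy to verify against the built-in count().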

package cn.spark.study.udf;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class StringCount extends UserDefinedAggregateFunction {

private static final long serialVersionUID = 1L;

/**
 * inputSchema: the schema of the function's input data
 */
@Override
public StructType inputSchema() {
    StructField[] fields = {DataTypes.createStructField("str", DataTypes.StringType, true)};
    return DataTypes.createStructType(fields);
}

/**
 * bufferSchema: the schema of the intermediate buffer used while aggregating
 */
@Override
public StructType bufferSchema() {
    StructField[] fields = {DataTypes.createStructField("count", DataTypes.IntegerType, true)};
    return DataTypes.createStructType(fields);
}

/**
 * dataType: the type of the function's return value
 */
@Override
public DataType dataType() {
    return DataTypes.IntegerType;
}

/**
 * deterministic: whether the same input always produces the same output
 */
@Override
public boolean deterministic() {
    return true;
}

/**
 * Initializes the aggregation buffer for each group
 */
@Override
public void initialize(MutableAggregationBuffer buffer) {
    buffer.update(0, 0);
}

/**
 * Called once per input row in a group: computes the group's new
 * aggregate value (here, the running count is incremented by one)
 */
@Override
public void update(MutableAggregationBuffer buffer, Row row) {
    Integer count = buffer.<Integer>getAs(0);
    buffer.update(0, count + 1);
}

/**
 * Merges two intermediate buffers for the same group, e.g. partial
 * counts computed on different partitions
 */
@Override
public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
    Integer count1 = buffer1.<Integer>getAs(0);
    Integer count2 = buffer2.<Integer>getAs(0);
    buffer1.update(0, count1 + count2);
}

/**
 * Produces a group's final result from its intermediate buffer
 */
@Override
public Object evaluate(Row buffer) {
    return buffer.<Integer>getAs(0);
}
}
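For a group such as "Tom", whose three rows may land on different partitions, Spark calls initialize on each partition's buffer, update once per row (producing, say, partial counts of 2 and 1), merge to combine the partial buffers into 3, and finally evaluate to return 3 as the group's aggregate.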

Usage:

package cn.spark.study.udf;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

@SuppressWarnings(value = {"unused"})
public class UdafSql {

/**
 * The following example registers a Scala closure as a UDF:
 *
 *   sqlContext.udf.register("myUDF", (arg1: Int, arg2: String) => arg2 + arg1)
 *
 * The following example registers a UDF in Java:
 *
 *   sqlContext.udf().register("myUDF",
 *       new UDF2<Integer, String, String>() {
 *           @Override
 *           public String call(Integer arg1, String arg2) {
 *               return arg2 + arg1;
 *           }
 *       }, DataTypes.StringType);
 *
 * Or, to use Java 8 lambda syntax:
 *
 *   sqlContext.udf().register("myUDF",
 *       (Integer arg1, String arg2) -> arg2 + arg1,
 *       DataTypes.StringType);
 *
 * @param args
 */
public static void main(String[] args) {
firstUdf();
}

private static void firstUdf(){
    SparkConf conf = new SparkConf().setAppName("UdfSql").setMaster("local");
    JavaSparkContext jsc = new JavaSparkContext(conf);
    SQLContext sqlct = new SQLContext(jsc);
    String[] str = {"Hpf99", "Leo", "Marray", "Jack", "Tom", "Tom", "Tom", "Leo", "Leo", "Marray", "Marray", "Jack"};
    List<String> lis = Arrays.asList(str);
    JavaRDD<String> strRdd = jsc.parallelize(lis);

    JavaRDD<Row> rowRdd = strRdd.mapPartitions(new FlatMapFunction<Iterator<String>, Row>() {

        private static final long serialVersionUID = 1L;

        @Override
        public Iterable<Row> call(Iterator<String> t) throws Exception {
            List<Row> lis = new ArrayList<Row>();
            while(t.hasNext()){
                String next = t.next();
                Row create = RowFactory.create(next);
                lis.add(create);
            }
            return lis;
        }
    });

    StructField[] fields = {DataTypes.createStructField("name", DataTypes.StringType, true)};
    StructType schema = DataTypes.createStructType(fields);
    DataFrame rowDF = sqlct.createDataFrame(rowRdd, schema);

    rowDF.registerTempTable("names");

    sqlct.udf().register("strCount", new StringCount());

    DataFrame sql = sqlct.sql("SELECT name, strCount(name) AS mycount, count(name) FROM names GROUP BY name");
    sql.show();

    jsc.close();
}

}
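Run locally, the query should print something like the following via show(). The counts follow directly from the input array; row order may differ, and the header Spark generates for the un-aliased count(name) column (shown here as _c2) depends on the Spark version:

+------+-------+---+
|  name|mycount|_c2|
+------+-------+---+
| Hpf99|      1|  1|
|   Leo|      3|  3|
|Marray|      3|  3|
|  Jack|      2|  2|
|   Tom|      3|  3|
+------+-------+---+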
