Calling a Java UDAF from PySpark

This post shows how to implement a user-defined aggregate function (UDAF) in Java and call it from PySpark.

## Java

```java
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.*;

/**
 * @description: example UDAF that sums price + quantity over all rows
 * @author: Mr.杜子腾
 * Spark  : 2.3.1
 * Python : 2.7
 **/
 
public class Summation extends UserDefinedAggregateFunction {

    /* Schema of the input arguments */
    @Override
    public StructType inputSchema() {
        // Note: StructType is immutable. add() returns a new instance, so the
        // calls must be chained; discarding the return value (as the original
        // code did) yields an empty schema.
        return new StructType()
                .add("price", DataTypes.DoubleType)
                .add("quantity", DataTypes.DoubleType);
    }

    /* Schema of the intermediate aggregation buffer */
    @Override
    public StructType bufferSchema() {
        return new StructType().add("total", DataTypes.DoubleType);
    }

    /* Data type of the final result */
    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    /* Determinism flag: if true, the function always returns the same result for the same input */
    @Override
    public boolean deterministic() {
        return true;
    }

    /* Set the initial value of the aggregation buffer. The initial value must satisfy this
     * contract: merging two initial buffers with the merge method below must again yield an
     * initial buffer. For example, if the initial value were 1 and merge performed addition,
     * merging two initial buffers would give 2, which is no longer an initial buffer; such an
     * initial value is invalid. This is why the initial value is also called the "zero value".
     */
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0.0);
    }

    /* Update the buffer with one input row, analogous to the per-element step of combineByKey */
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        double sum = buffer.getDouble(0);
        double price = input.getDouble(0);
        double qty = input.getDouble(1);
        // Accumulate price + quantity into the running total.
        buffer.update(0, sum + (price + qty));
    }

    // Merge two buffers, folding buffer2 into buffer1. This is called when combining partial
    // aggregates from different partitions, analogous to reduceByKey.
    // Note the method returns nothing: the merged result must be written back into buffer1.
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0));
    }

    // Compute and return the final aggregation result
    @Override
    public Object evaluate(Row row) {
        return row.getDouble(0);
    }

}
```
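
To make the four callbacks concrete, here is a plain-Python sketch (illustration only, not Spark API) of how Spark drives them: `update` folds rows into a per-partition buffer, `merge` combines the partial buffers, and `evaluate` reads out the result.

```python
# Plain-Python illustration of the UDAF lifecycle; not Spark code.
rows_p1 = [(32145.0, 4.1), (19575.0, 11.0)]  # partition 1: (price, quantity)
rows_p2 = [(42870.0, 2.0)]                   # partition 2

def initialize():            # the "zero value"
    return [0.0]

def update(buf, row):        # per-row accumulation within one partition
    buf[0] += row[0] + row[1]

def merge(buf1, buf2):       # fold one partial buffer into another
    buf1[0] += buf2[0]

buf1, buf2 = initialize(), initialize()
for r in rows_p1:
    update(buf1, r)
for r in rows_p2:
    update(buf2, r)
merge(buf1, buf2)
print(buf1[0])               # evaluate: 94607.1
```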

## Python

```python

# coding:utf-8

from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType  # needed by the commented-out UDF example below

def get_word(s):
    return str(s) + "udf"

conf = SparkConf()
# General tuning options carried over from the author's environment; none are required
# for the UDAF example. Note that spark.shuffle.manager and spark.shuffle.consolidateFiles
# appear to have been removed in Spark 2.x and are likely ignored here.
conf.set("spark.shuffle.manager", "tungsten-sort") \
    .set("spark.shuffle.sort.bypassMergeThreshold", "600") \
    .set("spark.local.dir", "/data/spark_cache/") \
    .set("spark.shuffle.consolidateFiles", "true") \
    .set("spark.driver.maxResultSize", "8g") \
    .set("spark.shuffle.file.buffer", "64") \
    .set("spark.kryoserializer.buffer.max", "256m") \
    .set("spark.reducer.maxSizeInFlight", "64") \
    .set("spark.shuffle.io.maxRetries", "60") \
    .set("spark.shuffle.io.retryWait", "60") \
    .set("spark.streaming.blockInterval", "30") \
    .set("spark.sql.warehouse.dir", "spark-warehouse") \
    .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
    .set("spark.executor.cores", "10") \
    .set("spark.cores.max", "20") \
    .set("spark.executor.memory", '8g') \
    .set("spark.driver.memory", '8g')
conf.setAppName("TSpark")
conf.setMaster("local")
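
# NOTE: the compiled UDAF class must be on the driver/executor classpath before
# registerJavaUDAF can resolve it. A minimal sketch, assuming the class is packaged
# as /home/XX/jar/udaf-example.jar (hypothetical path); with spark-submit you could
# pass --jars instead.
conf.set("spark.jars", "/home/XX/jar/udaf-example.jar")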


spark = SparkSession.builder.config(conf=conf).getOrCreate()
sc = spark.sparkContext

####################### spark udf #######################
# spark.udf.register("stringLengthInt", lambda x: get_word(x), StringType())
# df = spark.createDataFrame([(1, "a"), (22, "b"), (33, "a")], ["id", "name"])
# df.createOrReplaceTempView("df")  # registerTempTable is deprecated since Spark 2.0
# result = spark.sql("SELECT name, stringLengthInt(id) from df")
# result.show()
##########################################################


####################### spark udaf ######################
# Passing an RDD[String] to read.json is a legacy pattern; spark.read.json(path) also works.
rdd = sc.textFile("file:///home/XX/jar/test.json")  # run locally
df = spark.read.json(rdd)
df.createOrReplaceTempView("people")
# registerJavaUDAF(name, javaClassName) is available since Spark 2.3.
spark.udf.registerJavaUDAF("SUMPRICE", "com.fire.test.Summation")
df.printSchema()
spark.sql("select SUMPRICE(RetailValue,Stock) from people").show()

## Sample data

```json

{"Make":"Honda","Model":"Pilot","RetailValue":32145.0,"Stock":4.1}
{"Make":"Honda","Model":"Civic","RetailValue":19575.0,"Stock":11}
{"Make":"Honda","Model":"Ridgeline","RetailValue":42870.0,"Stock":2}
{"Make":"Jeep","Model":"Cherokee","RetailValue":23595.0,"Stock":13}
{"Make":"Jeep","Model":"Wrangler","RetailValue":27895.0,"Stock":4}
{"Make":"Volkswagen","Model":"Passat","RetailValue":22440.0,"Stock":2}```
