Spark Functions: UDF and UDAF

Spark ships with a rich set of functions: built-in functions, UDFs, UDAFs, and more. (A UDAF is used in aggregations (agg): it operates over multiple rows and returns a single aggregated value.)

There are many categories of built-in functions; see the official reference: Built-in Functions - Spark 3.0.1 Documentation (apache.org).
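For a quick taste (our own one-liner, not from the linked page; it assumes the SparkSession named spark that is created below), built-in functions can be invoked directly from SQL:

// Built-in string and math functions called from SQL (values chosen arbitrarily)
spark.sql("SELECT upper('hello') AS u, abs(-3) AS a, round(3.14159, 2) AS r").show();
// u = HELLO, a = 3, r = 3.14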

UDF
Usage steps: define a UserDefinedFunction (implement the call method), register it with spark.udf().register("name", userDefinedFunction), and then use it (in map, agg, or withColumn; a withColumn sketch follows the SQL examples below).

import org.apache.spark.sql.*;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import static org.apache.spark.sql.functions.udf;
import org.apache.spark.sql.types.DataTypes;

SparkSession spark = SparkSession
  .builder()
  .appName("Java Spark SQL UDF scalar example")
  .getOrCreate();

// Define and register a zero-argument non-deterministic UDF
// UDF is deterministic by default, i.e. produces the same result for the same input.
UserDefinedFunction random = udf(
  () -> Math.random(), DataTypes.DoubleType
).asNondeterministic(); // asNondeterministic() returns a new UDF rather than mutating in place,
                        // so its result must be kept; a bare `random.asNondeterministic();` is a no-op
spark.udf().register("random", random);
spark.sql("SELECT random()").show();
// +-------+
// |UDF()  |
// +-------+
// |xxxxxxx|
// +-------+

// Define and register a one-argument UDF
spark.udf().register("plusOne", new UDF1<Integer, Integer>() {
  @Override
  public Integer call(Integer x) {
    return x + 1;
  }
}, DataTypes.IntegerType);
spark.sql("SELECT plusOne(5)").show();
// +----------+
// |plusOne(5)|
// +----------+
// |         6|
// +----------+

// Define and register a two-argument UDF
UserDefinedFunction strLen = udf(
  (String s, Integer x) -> s.length() + x, DataTypes.IntegerType
);
spark.udf().register("strLen", strLen);
spark.sql("SELECT strLen('test', 1)").show();
// +------------+
// |UDF(test, 1)|
// +------------+
// |           5|
// +------------+

// UDF in a WHERE clause
spark.udf().register("oneArgFilter", new UDF1<Long, Boolean>() {
  @Override
  public Boolean call(Long x) {
    return x > 5;
  }
}, DataTypes.BooleanType);
spark.range(1, 10).createOrReplaceTempView("test");
spark.sql("SELECT * FROM test WHERE oneArgFilter(id)").show();
// +---+
// | id|
// +---+
// |  6|
// |  7|
// |  8|
// |  9|
// +---+
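As mentioned above, a registered UDF can also be used through the DataFrame API, for example with withColumn. A minimal sketch (the column name id_plus_one is our own; it reuses the plusOne UDF registered earlier):

import static org.apache.spark.sql.functions.callUDF;
import static org.apache.spark.sql.functions.col;

// Build a small integer column and apply the registered UDF to it
Dataset<Row> df = spark.range(1, 4).selectExpr("CAST(id AS INT) AS id");
df.withColumn("id_plus_one", callUDF("plusOne", col("id"))).show();
// +---+-----------+
// | id|id_plus_one|
// +---+-----------+
// |  1|          2|
// |  2|          3|
// |  3|          4|
// +---+-----------+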

UDAF: Aggregator
Aggregator<IN, BUF, OUT> is the typed aggregation API: IN is the input record type, BUF the intermediate buffer type, and OUT the final result type. Since Spark 3.0 it is the recommended way to write UDAFs.

import java.io.Serializable;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.TypedColumn;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.functions;

public static class Employee implements Serializable {
  private String name;
  private long salary;

  // Constructors, getters, setters...

}

public static class Average implements Serializable  {
  private long sum;
  private long count;

  // Constructors, getters, setters...

}

public static class MyAverage extends Aggregator<Employee, Average, Double> {
  // A zero value for this aggregation. Should satisfy the property that any b + zero = b
  public Average zero() {
    return new Average(0L, 0L);
  }
  // Combine two values to produce a new value. For performance, the function may modify `buffer`
  // and return it instead of constructing a new object
  public Average reduce(Average buffer, Employee employee) {
    long newSum = buffer.getSum() + employee.getSalary();
    long newCount = buffer.getCount() + 1;
    buffer.setSum(newSum);
    buffer.setCount(newCount);
    return buffer;
  }
  // Merge two intermediate values
  public Average merge(Average b1, Average b2) {
    long mergedSum = b1.getSum() + b2.getSum();
    long mergedCount = b1.getCount() + b2.getCount();
    b1.setSum(mergedSum);
    b1.setCount(mergedCount);
    return b1;
  }
  // Transform the output of the reduction
  public Double finish(Average reduction) {
    return ((double) reduction.getSum()) / reduction.getCount();
  }
  // Specifies the Encoder for the intermediate value type
  public Encoder<Average> bufferEncoder() {
    return Encoders.bean(Average.class);
  }
  // Specifies the Encoder for the final output value type
  public Encoder<Double> outputEncoder() {
    return Encoders.DOUBLE();
  }
}

Encoder<Employee> employeeEncoder = Encoders.bean(Employee.class);
String path = "examples/src/main/resources/employees.json";
Dataset<Employee> ds = spark.read().json(path).as(employeeEncoder);
ds.show();
// +-------+------+
// |   name|salary|
// +-------+------+
// |Michael|  3000|
// |   Andy|  4500|
// | Justin|  3500|
// |  Berta|  4000|
// +-------+------+

MyAverage myAverage = new MyAverage();
// Convert the function to a `TypedColumn` and give it a name
TypedColumn<Employee, Double> averageSalary = myAverage.toColumn().name("average_salary");
Dataset<Double> result = ds.select(averageSalary);
result.show();
// +--------------+
// |average_salary|
// +--------------+
// |        3750.0|
// +--------------+


// Register the aggregator so it can be called from SQL. functions.udaf takes an
// Encoder for the aggregator's input type, so that type must be the column type
// (Long for salary) rather than the Employee bean used above; a Long-based
// variant is sketched below.
spark.udf().register("myAverage", functions.udaf(new MyLongAverage(), Encoders.LONG()));
ds.createOrReplaceTempView("employees");
Dataset<Row> sqlResult = spark.sql("SELECT myAverage(salary) AS average_salary FROM employees");
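A minimal sketch of that Long-based variant (the class name MyLongAverage is our own; it reuses the Average buffer class defined above):

public static class MyLongAverage extends Aggregator<Long, Average, Double> {
  @Override
  public Average zero() {
    return new Average(0L, 0L);
  }
  // Fold one salary value into the running buffer
  @Override
  public Average reduce(Average buffer, Long salary) {
    buffer.setSum(buffer.getSum() + salary);
    buffer.setCount(buffer.getCount() + 1);
    return buffer;
  }
  // Merge partial buffers from different partitions
  @Override
  public Average merge(Average b1, Average b2) {
    b1.setSum(b1.getSum() + b2.getSum());
    b1.setCount(b1.getCount() + b2.getCount());
    return b1;
  }
  @Override
  public Double finish(Average reduction) {
    return ((double) reduction.getSum()) / reduction.getCount();
  }
  @Override
  public Encoder<Average> bufferEncoder() {
    return Encoders.bean(Average.class);
  }
  @Override
  public Encoder<Double> outputEncoder() {
    return Encoders.DOUBLE();
  }
}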

UDAF: UserDefinedAggregateFunction
UserDefinedAggregateFunction is the older, untyped UDAF API. It is deprecated as of Spark 3.0 in favor of Aggregator with functions.udaf, but it still appears in a lot of existing code.

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class MyAverage extends UserDefinedAggregateFunction {

    /**
     * Computes the average score; the input is the score column.
     * The StructField defines the column's name (score_column) and type (Double).
     * @return StructType describing the input schema
     */
    @Override
    public StructType inputSchema() {
        return new StructType()
                .add("score_column", DataTypes.DoubleType);
    }

    /**
     * Buffer schema that stores intermediate results.
     * Computing the average score requires the running sum and count of scores,
     * since average(score) = sum(score) / count(score), so the StructType
     * defines two StructFields: sum and count.
     * @return StructType describing the aggregation buffer
     */
    @Override
    public StructType bufferSchema() {
        return new StructType()
                .add("sum", DataTypes.DoubleType)
                .add("count", DataTypes.LongType);
    }

    /**
     * The data type of the aggregation's final result,
     * i.e. the type of average(score), hence Double.
     * @return DataType of the returned value
     */
    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }


    /**
     * Consistency check: whether the same input always yields the same output.
     * @return true, since this aggregation is deterministic
     */
    @Override
    public boolean deterministic() {
        return true;
    }


    /**
     * Initializes the buffered sum and count:
     * sum = 0.0, count = 0
     * @param mutableAggregationBuffer the aggregation buffer
     */
    @Override
    public void initialize(MutableAggregationBuffer mutableAggregationBuffer) {
        mutableAggregationBuffer.update(0, 0.0);
        mutableAggregationBuffer.update(1, 0L);

    }

    /**
     * Updates the buffer with one input row.
     * @param mutableAggregationBuffer the aggregation buffer
     * @param row the input row carrying the score
     */
    @Override
    public void update(MutableAggregationBuffer mutableAggregationBuffer, Row row) {
        if (!row.isNullAt(0)) { // skip null scores so getDouble(0) cannot fail
            mutableAggregationBuffer.update(0, mutableAggregationBuffer.getDouble(0) + row.getDouble(0));
            mutableAggregationBuffer.update(1, mutableAggregationBuffer.getLong(1) + 1);
        }
    }

    /**
     * Merges two partial aggregation buffers (e.g. computed on different partitions).
     * @param mutableAggregationBuffer the buffer that is merged into
     * @param row the other partial buffer, exposed as a Row
     */
    @Override
    public void merge(MutableAggregationBuffer mutableAggregationBuffer, Row row) {
        mutableAggregationBuffer.update(0, mutableAggregationBuffer.getDouble(0) + row.getDouble(0));
        mutableAggregationBuffer.update(1, mutableAggregationBuffer.getLong(1) + row.getLong(1));

    }


    /**
     * Computes the final result: average(score) = sum(score) / count(score)
     * @param row the final aggregation buffer
     * @return the average score as a Double
     */
    @Override
    public Object evaluate(Row row) {
        return row.getDouble(0)/row.getLong(1);
    }
}



Encoder<Employee> employeeEncoder = Encoders.bean(Employee.class);
String path = "examples/src/main/resources/employees.json";
Dataset<Employee> ds = spark.read().json(path).as(employeeEncoder);
ds.show();
// +-------+------+
// |   name|salary|
// +-------+------+
// |Michael|  3000|
// |   Andy|  4500|
// | Justin|  3500|
// |  Berta|  4000|
// +-------+------+

MyAverage myAverage = new MyAverage();
// Apply the UDAF as a Column expression over the whole Dataset
// (the bigint salary column is implicitly cast to the Double input type)
Dataset<Row> result = ds.agg(myAverage.apply(ds.col("salary")).as("average_salary"));
result.show();
// +--------------+
// |average_salary|
// +--------------+
// |        3750.0|
// +--------------+
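The deprecated API can also be registered by name and called from SQL (a sketch; the temp-view name employees_v is our own):

// register(String, UserDefinedAggregateFunction) is itself deprecated in Spark 3.0
spark.udf().register("myAverage", myAverage);
ds.createOrReplaceTempView("employees_v");
spark.sql("SELECT myAverage(salary) AS average_salary FROM employees_v").show();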