flink(十五):udf自定义函数


2022-04-08

说明

本博客每周五更新一次。
自定义函数(UDF)是一种Flink 扩展开发机制,可在查询语句里实现自定义的功能逻辑。
自定义函数可用 JVM 语言(例如 Java 或 Scala)或 Python 实现,推荐java或scala。

分享

资料

种类

  • UDF按功能大致分为4类(也可以3类,聚合函数和表值聚合函数算一类),如下表
名称说明
标量函数把0到多个标量值映射成 1 个标量值
表值函数把0到多个标量值映射成多行数据
聚合函数把一行或多行数据聚合为1个值
表值聚合函数把一行或多行数据聚合为多行

标量函数

说明

  • 标量函数必须继承 org.apache.flink.table.functions.ScalarFunction 类,实现 eval 方法,java实例代码如下:

实例

//-------------- 实现标量函数 ----------------
import org.apache.flink.table.annotation.InputGroup;
import org.apache.flink.table.api.*;
import org.apache.flink.table.functions.ScalarFunction;
import static org.apache.flink.table.api.Expressions.*;

public static class HashFunction extends ScalarFunction {

  // 接受任意类型输入,返回 INT 型输出
  public int eval(@DataTypeHint(inputGroup = InputGroup.ANY) Object o) {
    return o.hashCode();
  }
}


// 调用自定义函数
TableEnvironment env = TableEnvironment.create(...);

//-------------- 方式1 不注册函数 ----------------
// 在 Table API 里不经注册直接“内联”调用函数
env.from("MyTable").select(call(HashFunction.class, $("myField")));

//-------------- 方式2 注册函数 ----------------
// 注册函数
env.createTemporarySystemFunction("HashFunction", HashFunction.class);

// 在 Table API 里调用注册好的函数
env.from("MyTable").select(call("HashFunction", $("myField")));

// 在 SQL 里调用注册好的函数
env.sqlQuery("SELECT HashFunction(myField) FROM MyTable");

表值函数

说明

  • 实现类 org.apache.flink.table.functions.TableFunction,通过实现多个名为 eval 的方法对求值方法进行重载。

实例

//-------------- 实现表值函数 ----------------
import org.apache.flink.table.annotation.DataTypeHint;
import org.apache.flink.table.annotation.FunctionHint;
import org.apache.flink.table.api.*;
import org.apache.flink.table.functions.TableFunction;
import org.apache.flink.types.Row;
import static org.apache.flink.table.api.Expressions.*;

@FunctionHint(output = @DataTypeHint("ROW<word STRING, length INT>"))
public static class SplitFunction extends TableFunction<Row> {

  public void eval(String str) {
    for (String s : str.split(" ")) {
      // use collect(...) to emit a row
      collect(Row.of(s, s.length()));
    }
  }
}

//-------------- 使用标量函数 ----------------
TableEnvironment env = TableEnvironment.create(...);

//-------------- 方式1:不注册使用 ----------------
// 在 Table API 里不经注册直接“内联”调用函数
env
  .from("MyTable")
  .joinLateral(call(SplitFunction.class, $("myField")))
  .select($("myField"), $("word"), $("length"));
env
  .from("MyTable")
  .leftOuterJoinLateral(call(SplitFunction.class, $("myField")))
  .select($("myField"), $("word"), $("length"));

// 在 Table API 里重命名函数字段
env
  .from("MyTable")
  .leftOuterJoinLateral(call(SplitFunction.class, $("myField")).as("newWord", "newLength"))
  .select($("myField"), $("newWord"), $("newLength"));

//-------------- 方式1:注册使用 ----------------
// 注册函数
env.createTemporarySystemFunction("SplitFunction", SplitFunction.class);

// 在 Table API 里调用注册好的函数
env
  .from("MyTable")
  .joinLateral(call("SplitFunction", $("myField")))
  .select($("myField"), $("word"), $("length"));
env
  .from("MyTable")
  .leftOuterJoinLateral(call("SplitFunction", $("myField")))
  .select($("myField"), $("word"), $("length"));

// 在 SQL 里调用注册好的函数
env.sqlQuery(
  "SELECT myField, word, length " +
  "FROM MyTable, LATERAL TABLE(SplitFunction(myField))");
env.sqlQuery(
  "SELECT myField, word, length " +
  "FROM MyTable " +
  "LEFT JOIN LATERAL TABLE(SplitFunction(myField)) ON TRUE");

// 在 SQL 里重命名函数字段
env.sqlQuery(
  "SELECT myField, newWord, newLength " +
  "FROM MyTable " +
  "LEFT JOIN LATERAL TABLE(SplitFunction(myField)) AS T(newWord, newLength) ON TRUE");

聚合函数

说明

  • 自定义聚合函数(UDAGG)是把一个表(一行或者多行,每行可以有一列或者多列)聚合成一个标量值。
    在这里插入图片描述

  • 如上图,有一个关于饮料的表,有三个字段id、name、price,有 5 行数据。假设需要找到所有饮料里最贵的饮料价格,即执行一个 max() 聚合。需要遍历所有5行数据,结果只有一个数值。

  • 自定义聚合函数是通过扩展 AggregateFunction 来实现的。AggregateFunction 需要 accumulator 定义数据结构,存储了聚合的中间结果。通过 AggregateFunction 的 createAccumulator() 方法创建一个空的 accumulator。对于每一行数据,会调用 accumulate() 方法来更新 accumulator。当所有的数据都处理完了之后,通过调用 getValue() 计算和返回最终结果。

  • 因此实现AggregateFunction 必须实现方法:createAccumulator()accumulate()getValue()

  • 某些场景下还需要实现其他方法。

    • retract() 在 bounded OVER 窗口中是必须实现的。
    • merge() 在许多批式聚合和会话以及滚动窗口聚合中是必须实现的。除此之外,这个方法对于优化也很多帮助。例如,两阶段聚合优化就需要所有的 AggregateFunction 都实现 merge 方法。
    • resetAccumulator() 在许多批式聚合中是必须实现的。

代码实例

//----------------创建数据对象 ----------------
/**
 * Accumulator for WeightedAvg.
 */
public static class WeightedAvgAccum {
    public long sum = 0;
    public int count = 0;
}

//-------------- 定义聚合函数 ----------------

/**
 * Weighted Average user-defined aggregate function.
 */
public static class WeightedAvg extends AggregateFunction<Long, WeightedAvgAccum> {

    @Override
    public WeightedAvgAccum createAccumulator() {
        return new WeightedAvgAccum();
    }

    @Override
    public Long getValue(WeightedAvgAccum acc) {
        if (acc.count == 0) {
            return null;
        } else {
            return acc.sum / acc.count;
        }
    }

    public void accumulate(WeightedAvgAccum acc, long iValue, int iWeight) {
        acc.sum += iValue * iWeight;
        acc.count += iWeight;
    }

    public void retract(WeightedAvgAccum acc, long iValue, int iWeight) {
        acc.sum -= iValue * iWeight;
        acc.count -= iWeight;
    }

    public void merge(WeightedAvgAccum acc, Iterable<WeightedAvgAccum> it) {
        Iterator<WeightedAvgAccum> iter = it.iterator();
        while (iter.hasNext()) {
            WeightedAvgAccum a = iter.next();
            acc.count += a.count;
            acc.sum += a.sum;
        }
    }

    public void resetAccumulator(WeightedAvgAccum acc) {
        acc.count = 0;
        acc.sum = 0L;
    }
}

//-------------- 使用聚合函数 ----------------
// 注册函数
StreamTableEnvironment tEnv = ...
tEnv.registerFunction("wAvg", new WeightedAvg());

// 使用函数
tEnv.sqlQuery("SELECT user, wAvg(points, level) AS avgPoints FROM userScores GROUP BY user");


表值聚合函数

说明

  • 自定义表值聚合函数(UDTAGG)可以把一个表(一行或者多行,每行有一列或者多列)聚合成另一张表,结果中可以有多行多列。
    在这里插入图片描述

  • 如上图有一个表,3个字段分别为 id、name 和 price 共 5 行。假设需要找到价格最高的两个饮料,类似于 top2() 表值聚合函数。需要遍历所有 5 行数据,结果是有 2 行数据的一个表。

  • 自定义表值聚合函数通过扩展 TableAggregateFunction 类来实现的,具体执行过程如下。首先,需要一个 accumulator 负责存储聚合的中间结果。 通过调用 TableAggregateFunction 的 createAccumulator() 方法来一个空的 accumulator。对于每一行数据,调用 accumulate() 方法更新 accumulator。当所有数据都处理完之后,调用 emitValue() 方法计算和返回最终的结果。

  • 实现TableAggregateFunction 必须要实现的方法:createAccumulator()accumulate()

  • 某些场景下必须实现的方法:

    • retract() 在 bounded OVER 窗口中的聚合函数必须要实现。
    • merge() 在许多批式聚合和以及流式会话和滑动窗口聚合中是必须要实现的。
    • resetAccumulator() 在许多批式聚合中是必须要实现的。
    • emitValue() 在批式聚合以及窗口聚合中是必须要实现的。
  • emitUpdateWithRetract() 在 retract 模式下,可以提升人物效率,该方法负责发送被更新的值。

代码实例

  • 定义TableAggregateFunction 来计算给定列的最大的 2 个值,在 TableEnvironment 中注册函数,在 Table API 查询中使用函数(当前只在 Table API 中支持 TableAggregateFunction)。
//----------------创建数据对象 ----------------
/**
 * Accumulator for Top2.
 */
public class Top2Accum {
    public Integer first;
    public Integer second;
}

//-------------- 定义聚合函数 ----------------
/**
 * The top2 user-defined table aggregate function.
 */
public static class Top2 extends TableAggregateFunction<Tuple2<Integer, Integer>, Top2Accum> {

    @Override
    public Top2Accum createAccumulator() {
        Top2Accum acc = new Top2Accum();
        acc.first = Integer.MIN_VALUE;
        acc.second = Integer.MIN_VALUE;
        return acc;
    }


    public void accumulate(Top2Accum acc, Integer v) {
        if (v > acc.first) {
            acc.second = acc.first;
            acc.first = v;
        } else if (v > acc.second) {
            acc.second = v;
        }
    }

    public void merge(Top2Accum acc, java.lang.Iterable<Top2Accum> iterable) {
        for (Top2Accum otherAcc : iterable) {
            accumulate(acc, otherAcc.first);
            accumulate(acc, otherAcc.second);
        }
    }

    public void emitValue(Top2Accum acc, Collector<Tuple2<Integer, Integer>> out) {
        // emit the value and rank
        if (acc.first != Integer.MIN_VALUE) {
            out.collect(Tuple2.of(acc.first, 1));
        }
        if (acc.second != Integer.MIN_VALUE) {
            out.collect(Tuple2.of(acc.second, 2));
        }
    }
}

//-------------- 使用聚合函数 ----------------
// 注册函数
StreamTableEnvironment tEnv = ...
tEnv.registerFunction("top2", new Top2());

// 初始化表
Table tab = ...;

// 使用函数
tab.groupBy("key")
    .flatAggregate("top2(a) as (v, rank)")
    .select("key, v, rank");
  • 下面例子使用 emitUpdateWithRetract 方法来只发送更新的数据。为了只发送更新的结果,accumulator 保存上一次的最大2个值,也保存了当前最大2个值。
  • 注意:如果 TopN 中的 n 非常大,这种既保存上次的结果,也保存当前的结果的方式不太高效。一种解决这种问题的方式是把输入数据直接存储到 accumulator 中,然后在调用 emitUpdateWithRetract 方法时再进行计算。
//----------------创建数据对象 ----------------
/**
 * Accumulator for Top2.
 */
public class Top2Accum {
    public Integer first;
    public Integer second;
    public Integer oldFirst;
    public Integer oldSecond;
}

//-------------- 定义聚合函数 ----------------
/**
 * The top2 user-defined table aggregate function.
 */
public static class Top2 extends TableAggregateFunction<Tuple2<Integer, Integer>, Top2Accum> {

    @Override
    public Top2Accum createAccumulator() {
        Top2Accum acc = new Top2Accum();
        acc.first = Integer.MIN_VALUE;
        acc.second = Integer.MIN_VALUE;
        acc.oldFirst = Integer.MIN_VALUE;
        acc.oldSecond = Integer.MIN_VALUE;
        return acc;
    }

    public void accumulate(Top2Accum acc, Integer v) {
        if (v > acc.first) {
            acc.second = acc.first;
            acc.first = v;
        } else if (v > acc.second) {
            acc.second = v;
        }
    }

    public void emitUpdateWithRetract(Top2Accum acc, RetractableCollector<Tuple2<Integer, Integer>> out) {
        if (!acc.first.equals(acc.oldFirst)) {
            // if there is an update, retract old value then emit new value.
            if (acc.oldFirst != Integer.MIN_VALUE) {
                out.retract(Tuple2.of(acc.oldFirst, 1));
            }
            out.collect(Tuple2.of(acc.first, 1));
            acc.oldFirst = acc.first;
        }

        if (!acc.second.equals(acc.oldSecond)) {
            // if there is an update, retract old value then emit new value.
            if (acc.oldSecond != Integer.MIN_VALUE) {
                out.retract(Tuple2.of(acc.oldSecond, 2));
            }
            out.collect(Tuple2.of(acc.second, 2));
            acc.oldSecond = acc.second;
        }
    }
}

//-------------- 使用聚合函数 ----------------s
// 注册函数
StreamTableEnvironment tEnv = ...
tEnv.registerFunction("top2", new Top2());

// 初始化表
Table tab = ...;

// 使用函数
tab.groupBy("key")
    .flatAggregate("top2(a) as (v, rank)")
    .select("key, v, rank");

总结

  • 个人感觉UDF本质是抽象类的实现,扩展了Flink计算能力。
  • 2
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值