Pyflink教程(三):自定义函数

该文章例子pyflink环境是apache-flink==1.13.6

Python 自定义函数是 PyFlink Table API 中最重要的功能之一,其允许用户在 PyFlink Table API 中使用 Python 语言开发的自定义函数,极大地拓宽了 Python Table API 的使用范围。

简单来说 就是有的业务逻辑和需求是sql语句满足不了或太麻烦的,需要用过函数来实现。

Python UDF

Python UDF,即 Python ScalarFunction,针对每一条输入数据,仅产生一条输出数据。

env_settings = EnvironmentSettings.new_instance().in_streaming_mode().build()
t_env = StreamTableEnvironment.create(environment_settings=env_settings)
table = t_env.from_elements([("hello&11", 1), ("world&22", 2), ("flink&33", 3)], ['a', 'b'])


#方式一:
#result_type 是输出类型,如果是多个返回值,则需写result_types
#同理也可以指定输入类型,input_type,多个返回值写input_types
@udf(result_type=DataTypes.STRING())
def sub_string(s: str, begin: int, end: int):
    return s[begin:end]

@udf(result_type=DataTypes.STRING())
def split_t(s):
    ab = s.split('&')
    return ab[0]

result = table.select(split_t(table.a).alias('a'))

# 方式二:
sub_string_lambda_fun = udf(lambda s, begin, end: s[begin:end], result_type=DataTypes.STRING())
result = table.select(sub_string_lambda_fun(table.a, 1, 3))

方式三:
# 继承ScalarFunction
# 实现eval方法,来实现方法
# open在初始化时执行一次,跟java的富函数一样,比如需要全局执行一次的(mysql连接等),可以放在open方法中执行, 
# 可以注册Metrics对象 https://nightlies.apache.org/flink/flink-docs-release-1.16/docs/dev/python/table/metrics/
class SubString(ScalarFunction):
    def open(self, function_context):
        #super().open(function_context)
        #self.counter = function_context.get_metric_group().counter("my_counter")
         pass

    def eval(self, s: str, begin: int, end: int):
        return s[begin:end]

sub_string = udf(SubString(), result_type=DataTypes.STRING())


result.execute().print()#直接打印
# result = result.to_pandas() ##这里可以转换成pandas
# 也可以用with遍历  
+----+--------------------------------+
| op |                              a |
+----+--------------------------------+
| +I |                          hello |
| +I |                          world |
| +I |                          flink |
+----+--------------------------------+

Python UDTF

Python UDTF,即 Python TableFunction,针对每一条输入数据,Python UDTF 可以产生 0 条、1 条或者多条输出数据,此外,一条输出数据可以包含多个列。比如以下示例,定义了一个名字为 split 的Python UDF,以指定字符串为分隔符,将输入字符串切分成两个字符串:

from pyflink.table.udf import udtf
from pyflink.table import DataTypes
env_settings = EnvironmentSettings.new_instance().in_streaming_mode().build()
t_env = StreamTableEnvironment.create(environment_settings=env_settings)
table = t_env.from_elements([("hello&11", 1), ("world&22", 2), ("flink&33", 3)], ['a', 'b'])
@udtf(result_types=[DataTypes.STRING(), DataTypes.STRING()])
def split(s: str, sep: str):
    splits = s.split(sep)
    yield splits[0], splits[1]
#合并两个结果集
#可以使用左、右和内等连接查询
#result = table.join_lateral(split(table.a, '&'))
result = table.left_outer_join_lateral(split(table.a, '&'))

result.execute().print()
+----+--------------------------------+----------------------+--------------------------------+--------------------------------+
| op |                              a |                    b |                             f0 |                             f1 |
+----+--------------------------------+----------------------+--------------------------------+--------------------------------+
| +I |                       hello&11 |                    1 |                          hello |                             11 |
| +I |                       world&22 |                    2 |                          world |                             22 |
| +I |                       flink&33 |                    3 |                          flink |                             33 |
+----+--------------------------------+----------------------+--------------------------------+--------------------------------+

Python UDAF

Python UDAF,即 Python AggregateFunction。Python UDAF 用来针对一组数据进行聚合运算,比如同一个 window 下的多条数据、或者同一个 key 下的多条数据等。针对同一组输入数据,Python AggregateFunction 产生一条输出数据。比如以下示例,定义了一个名字为 weighted_avg 的 Python UDAF:

from pyflink.common import Row
from pyflink.table import AggregateFunction, DataTypes, EnvironmentSettings, StreamTableEnvironment
from pyflink.table.udf import udaf


class WeightedAvg(AggregateFunction):
    ## ImperativeAggregateFunction 类需要实现的抽象类
    def create_accumulator(self):
        print("111")
        # Row(sum, count)
        return Row(0, 0)
    # AggregateFunction 类 需要实现的抽象类
    def get_value(self, retract) -> float:
        if retract[1] == 0:
            return 0
        else:
            return retract[0] / retract[1]
    ## ImperativeAggregateFunction 类需要实现的抽象类
    # 累加器方法
    def accumulate(self, accumulator, value, weight):
        print(value, weight)
        accumulator[0] += value * weight
        accumulator[1] += weight
    # 缩减方法,这个不需要必须实现
    def retract(self, accumulator: Row, value, weight):
        accumulator[0] -= value * weight
        accumulator[1] -= weight


weighted_avg = udaf(f=WeightedAvg(),
                    result_type=DataTypes.DOUBLE(),
                    accumulator_type=DataTypes.ROW([
                        DataTypes.FIELD("f0", DataTypes.BIGINT()),
                        DataTypes.FIELD("f1", DataTypes.BIGINT())]))
env_settings = EnvironmentSettings.new_instance().in_streaming_mode().use_blink_planner().build()
t_env = StreamTableEnvironment.create(environment_settings=env_settings)

t = t_env.from_elements([(1, 2, "Lee"), (3, 4, "Jay"), (5, 6, "Jay"), (7, 8, "Lee")],
                        ["value", "count", "name"])

result = t.group_by(t.name).select(weighted_avg(t.value, t.count).alias("avg"))
result.execute().print()
+----+--------------------------------+
| op |                            avg |
+----+--------------------------------+
| +I |                            5.8 |
| +I |                            4.2 |
+----+--------------------------------+

Python UDTAF

Python UDTAF,即 Python TableAggregateFunction。Python UDTAF 用来针对一组数据进行聚合运算,比如同一个 window 下的多条数据、或者同一个 key 下的多条数据等,与 Python UDAF 不同的是,针对同一组输入数据,Python UDTAF 可以产生 0 条、1 条、甚至多条输出数据。

from pyflink.common import Row
from pyflink.table import DataTypes, EnvironmentSettings, StreamTableEnvironment
from pyflink.table.udf import udtaf, TableAggregateFunction


class Top2(TableAggregateFunction):

    def create_accumulator(self):
        # 存储当前最大的两个值
        return [None, None]

    def accumulate(self, accumulator, input_row):
        if input_row[0] is not None:
            # 新的输入值最大
            if accumulator[0] is None or input_row[0] > accumulator[0]:
                accumulator[1] = accumulator[0]
                accumulator[0] = input_row[0]
            # 新的输入值次大
            elif accumulator[1] is None or input_row[0] > accumulator[1]:
                accumulator[1] = input_row[0]

    def emit_value(self, accumulator):
        yield Row(accumulator[0])
        if accumulator[1] is not None:
            yield Row(accumulator[1])


top2 = udtaf(f=Top2(),
             result_type=DataTypes.ROW([DataTypes.FIELD("a", DataTypes.BIGINT())]),
             accumulator_type=DataTypes.ARRAY(DataTypes.BIGINT()))
env_settings = EnvironmentSettings.new_instance().in_streaming_mode().use_blink_planner().build()
t_env = StreamTableEnvironment.create(environment_settings=env_settings)

t = t_env.from_elements([(1, 'Hi', 'Hello'),
                         (3, 'Hi', 'hi'),
                         (5, 'Hi2', 'hi'),
                         (2, 'Hi', 'Hello'),
                         (7, 'Hi', 'Hello')],
                        ['a', 'b', 'c'])

t_env.execute_sql("""
       CREATE TABLE my_sink (
         word VARCHAR,
         `sum` BIGINT
       ) WITH (
         'connector' = 'print'
       )
    """)

result = t.group_by(t.b).flat_aggregate(top2).select("b, a").execute_insert("my_sink")

# 1)等待作业执行结束,用于local执行,否则可能作业尚未执行结束,该脚本已退出,会导致minicluster过早退出
# 2)当作业通过detach模式往remote集群提交时,比如YARN/Standalone/K8s等,需要移除该方法
result.wait()
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值