【大数据】Presto开发自定义聚合函数

Presto 在交互式查询任务中担当着重要的职责。随着越来越多的人开始使用 SQL 在 Presto 上分析数据,我们发现需要将一些业务逻辑开发成类似 Hive 中的 UDF,提高 SQL 使用人员的效率,同时也保证 Hive 和 Presto 环境中的 UDF 统一。

1、Presto函数介绍

在此之前先简单介绍下UDF和UDAF,UDF叫做用户自定义函数,而UDAF叫做用户自定义聚合函数,区别就在于UDF不会保存状态,一行输入一行输出,而UDAF会涉及到状态的保存,通过聚合多个节点的数据来转换为最终的输出结果。

在 Presto 中,函数大体分为三种:scalar,aggregation 和 window 类型。分别如下:

1)scalar标量函数,简单来说就是 Java 中的一个静态方法,本身没有任何状态(不保存数据,一行输入一行输出)。

2)aggregation累积状态的函数,或聚集函数,如count,avg。如果只是单节点,单机状态可以直接用一个变量存储即可,但是presto是分布式计算引擎,状态数据会在多个节点之间传输,因此状态数据需要被序列化成 Presto 的内部格式才可以被传输。简单来说Aggregation对应于多行输入一行输出。

3)window 窗口函数,窗口函数在查询结果的行上进行计算,执行顺序在HAVING子句之后,ORDER BY子句之前。在 Presto SQL 中,窗口函数的语法形式如下:

windowFunction(arg1,....argn) OVER([PARTITION BY<...>] [ORDER BY<...>] [RANGE|ROWS 
BETWEEN AND])

窗口函数语法由关键字OVER触发,且包含三个子句:
PARTITION BY: 指定输入行分区的规则,类似于聚合函数的GROUP BY子句,不同分区里的计算互不干扰(窗口函数的计算是并发进行的,并发数和partition数量一致),缺省时将所有数据行视为一个分区
ORDER BY: 决定了窗口函数处理输入行的顺序
RANGE|ROWS BETWEEN AND: 指定窗口边界,不常用,缺省时的窗口为当前行所在的分区第一行到当前行。

2、自定义函数

官方文档地址:https://prestodb.io/docs/current/develop/functions.html

2.1自定义Scalar函数的实现

2.1.1定义一个java类
1)用 @ScalarFunction 的 Annotation 标记实现业务逻辑的静态方法。

2)用 @Description 描述函数的作用,这里的内容会在 SHOW FUNCTIONS 中显示。

3)用@SqlType 标记函数的返回值类型,如返回字符串,因此是 StandardTypes.VARCHAR。

4)Java 方法的返回值必须使用 Presto 内部的序列化方式,因此字符串类型必须返回 Slice, 使用 Slices.utf8Slice 方法可以方便的将 String 类型转换成 Slice 类型

public class ExampleStringFunction
{
    @ScalarFunction("lowercaser")
    @Description("converts the string to alternating case")
    @SqlType(StandardTypes.VARCHAR)
    public static Slice lowercaser(@SqlType(StandardTypes.VARCHAR) Slice slice)
    {
        String argument = slice.toStringUtf8();
        return Slices.utf8Slice(argument.toLowerCase());
    }
}

2.2 自定义Aggregation函数

2.2.1实现原理步骤
Presto 把 Aggregation 函数分解成三个步骤执行:

1、input(state, data): 针对每条数据,执行 input 函数。这个过程是并行执行的,因此在每个有数据的节点都会执行,最终得到多个累积的状态数据。

2、combine(state1, state2):将所有节点的状态数据聚合起来,多次执行,直至所有状态数据被聚合成一个最终状态,也就是 Aggregation 函数的输出结果。

3、output(final_state, out):最终输出结果到一个 BlockBuilder。

2.2.2 具体代码实现过程
1、定义一个 Java 类,使用 @AggregationFunction 标记为 Aggregation 函数

2、使用 @InputFunction、 @CombineFunction、@OutputFunction 分别标记计算函数、合并结果函数和最终输出函数在 Plugin 处注册 Aggregation 函数

3、一个继承AccumulatorState的State接口,get和set方法

4、并使用 @AccumulatorStateMetadata 提供序列化(stateSerializerClass指定)和 Factory 类信息(stateFactoryClass指定)。自己写一个序列化类和一个工厂类。(复杂数据类型需要:自定义类保存状态、Map、List等)

简单类型Aggregation

对于简单数据类型的聚合函数编写比较简单,实现一个包含input、combine、output的aggregation和一个状态设定接口State提供get、set方法即可,不用去关心序列化和状态保存问题。
Aggregation:

@AggregationFunction("avg")
public final class IntervalYearToMonthAverageAggregation
{
    private IntervalYearToMonthAverageAggregation() {}

    @InputFunction
    public static void input(LongAndDoubleState state, @SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long value)
    {
        state.setLong(state.getLong() + 1);
        state.setDouble(state.getDouble() + value);
    }

    @CombineFunction
    public static void combine(LongAndDoubleState state, LongAndDoubleState otherState)
    {
        state.setLong(state.getLong() + otherState.getLong());
        state.setDouble(state.getDouble() + otherState.getDouble());
    }

    @OutputFunction(StandardTypes.INTERVAL_YEAR_TO_MONTH)
    public static void output(LongAndDoubleState state, BlockBuilder out)
    {
        long count = state.getLong();
        if (count == 0) {
            out.appendNull();
        }
        else {
            double value = state.getDouble();
            INTERVAL_YEAR_MONTH.writeLong(out, round(value / count));
        }
    }
}

LongAndDoubleState :写一个接口实现继承自AccumulatorState类,提供get、set方法即可。

public interface LongAndDoubleState
        extends AccumulatorState
{
    long getLong();

    void setLong(long value);

    double getDouble();

    void setDouble(double value);
}
复杂类型Aggregation

对于复杂数据类型则需要提供序列化机制,你要序列化那些东西都是由你来制指定的。在AccumulatorState的接口上用注解指定@AccumulatorStateMetadata 提供序列化(stateSerializerClass指定)和 Factory 类信息(stateFactoryClass指定),自定义一个序列化器和序列化工厂类,实现类的序列化和反序列化。

Aggregation类: 这个类实现比较简单,和简单数据类型的实现一样,input、combine、output。

@AggregationFunction("presto_collect")
public class CollectListAggregation {

   @InputFunction
    public static void input(@AggregationState CollectState state, @SqlType(StandardTypes.VARCHAR) Slice id,@SqlType(StandardTypes.VARCHAR) Slice key) {
        try {
            CollectListStats stats = state.get();
            if (stats == null) {
                stats = new CollectListStats();
                state.set(stats);
            }
            int inputId = Integer.parseInt(id.toStringUtf8());
            String inputKey = key.toStringUtf8();
            stats.addCollectList(inputId,inputKey, 1);
       } catch (Exception e) {
            throw new RuntimeException(e+" ---------  input err");
        }
    }

    @CombineFunction
    public static void combine(@AggregationState CollectState state, CollectState otherState) {
        try {
            CollectListStats collectListStats = state.get();
            CollectListStats oCollectListStats = otherState.get();
            if(collectListStats == null) {
                state.set(oCollectListStats);
            } else {
                collectListStats.mergeWith(oCollectListStats);
            }
        }catch (Exception e) {
            throw new RuntimeException(e+" --------- combine err");
        }
    }

    @OutputFunction(StandardTypes.*VARCHAR*)
    public static void output(@AggregationState CollectState state, BlockBuilder out) {
        try {
            CollectListStats stats = state.get();
            if (stats == null) {
                out.appendNull();
                return;
            }
            // 统计
            Slice result = stats.getCollectResult();
            if (result == null) {
                out.appendNull();
            } else {
                VarcharType.VARCHAR.writeSlice(out, result);
            }
        } catch (Exception e) {
            throw new RuntimeException(e+" -------- output err");
        }
    }
}

状态保存接口:

@AccumulatorStateMetadata(stateSerializerClass = CollectListStatsSerializer.class, stateFactoryClass = CollectListStatsFactory.class)
public interface CollectState extends AccumulatorState {
    CollectListStats get();
    void set(CollectListStats value);
}

存放数据的类:此类需要实现数据的序列化和反序列化,这是最关键和比较麻烦的地方,贴一个例子,关键在于需要自己控制存储空间以及数据的顺序,和读取的时候按照一定顺序读取。对于字符要先存储长度,然后是字节,读取则先读取字符长度,然后读取这么长的数据,最后转化为字符。

public class CollectListStats {
    private static final int INSTANCE_SIZE = ClassLayout.parseClass(CollectListStats.class).instanceSize();
    //<id,<key,value>>
    private Map<Integer,Map<String,Integer>> collectContainer = new HashMap<>();
    private long contentEstimatedSize = 0;
    private int keyByteLen = 0;
    private int keyListLen = 0;
    CollectListStats() {
    }
    CollectListStats(Slice serialized) {
        deserialize(serialized);
    }
   void addCollectList(Integer id, String key, int value) {
        if (collectContainer.containsKey(id)) {
            Map<String, Integer> tmpMap = collectContainer.get(id);
            if (tmpMap.containsKey(key)) {
                tmpMap.put(key, tmpMap.get(key)+value);
            }else{
                tmpMap.put(key,value);
                contentEstimatedSize += ( key.getBytes().length + SizeOf.SIZE_OF_INT*);
                keyByteLen += key.getBytes().length;
                keyListLen++;
            }
        } else {
            Map<String,Integer> tmpMap = new HashMap<String,Integer>();
            tmpMap.put(key, value);
            keyByteLen += key.getBytes().length;
            keyListLen++;
            collectContainer.put(id, tmpMap);
            contentEstimatedSize += SizeOf.SIZE_OF_INT;
        }
    }
    //[{id:1,{"aaa":3,"fadf":6},{}]
    Slice getCollectResult() {
        Slice jsonSlice = null;
        try {
            StringBuilder jsonStr = new StringBuilder();
            jsonStr.append("[");
            int collectLength = collectContainer.entrySet().size();
            for (Map.Entry<Integer, Map<String, Integer>> mapEntry : collectContainer.entrySet()) {
                Integer id = mapEntry.getKey();
                Map<String, Integer> vMap = mapEntry.getValue();
                jsonStr.append("{id:").append(id).append(",{");
                int vLength = vMap.entrySet().size();
                for (Map.Entry<String, Integer> vEntry : vMap.entrySet()) {
                    String key = vEntry.getKey();
                    Integer value = vEntry.getValue();
                    jsonStr.append(key).append(":").append(value);
                    vLength--;
                    if (vLength != 0) {
                        jsonStr.append(",");
                    }
                }
                jsonStr.append("}");
                collectLength--;
                if (collectLength != 0) {
                    jsonStr.append(",");
                }
            }
            jsonStr.append("]");
            jsonSlice = Slices.utf8Slice*(jsonStr.toString());
        } catch (Exception e) {
            throw new RuntimeException(e+" ---------- get CollectResult err");
        }
        return jsonSlice;
    }
    public void deserialize(Slice serialized) {
        try {
            SliceInput input = serialized.getInput();
            //外层map的长度
            int collectStatsEntrySize = input.readInt();
            for (int collectCnt = 0; collectCnt < collectStatsEntrySize; collectCnt++) {

                int id = input.readInt();
                int keyEntrySize = input.readInt();
                for (int idCnt = 0; idCnt < keyEntrySize; idCnt++) {
                    int keyBytesLen = input.readInt();
                    byte[] keyBytes = new byte[keyBytesLen];
                    for (int byteIdx = 0; byteIdx < keyBytesLen; byteIdx++) {
                        keyBytes[byteIdx] = input.readByte();
                    }
                    String key = new String(keyBytes);
                    int value = input.readInt();
                    addCollectList(id, key, value);
                }
            }
        } catch (Exception e) {
            throw new RuntimeException(e+" ----- deserialize err");
        }
    }

    public Slice serialize() {
        SliceOutput builder = null;
        int requiredBytes =                                                 //对应 SliceOutput builder append的内容所占用的空间
                SizeOf.SIZE_OF_INT*3                                      //id entry数目,id数值,key Entry数目
                \+ keyListLen * SizeOf.SIZE_OF_INT*                           //key bytes长度
                \+ keyByteLen*                                                //key byte总长度
                \+ keyListLen * SizeOf.SIZE_OF_INT;                          //value
        try {
            // 序列化
            builder = Slices.*allocate*(requiredBytes).getOutput();
            for (Map.Entry<Integer,Map<String, Integer>> entry : collectContainer.entrySet()) {
                //id个数
                builder.appendInt(collectContainer.entrySet().size());
                //id 数值
                builder.appendInt(entry.getKey());
                Map<String, Integer> kMap = entry.getValue();
                builder.appendInt(kMap.entrySet().size());
                for (Map.Entry<String, Integer> vEntry : kMap.entrySet()) {
                    byte[] keyBytes = vEntry.getKey().getBytes();
                    builder.appendInt(keyBytes.length);
                    builder.appendBytes(keyBytes);
                    builder.appendInt(vEntry.getValue());
                }
            }
            return builder.getUnderlyingSlice();
        } catch (Exception e) {
            throw new RuntimeException(e+" ---- serialize err  requiredBytes = " + requiredBytes + " keyByteLen= " + keyByteLen + " keyListLen = " + keyListLen);
       }
    }
    long estimatedInMemorySize() {
        return INSTANCE_SIZE + contentEstimatedSize;
    }
    void mergeWith(CollectListStats other) {
        if (other == null) {
            return;
        }
       for (Map.Entry<Integer,Map<String, Integer>> cEntry : other.collectContainer.entrySet()) {
            Integer id = cEntry.getKey();
            Map<String, Integer> kMap = cEntry.getValue();
            for (Map.Entry<String, Integer> kEntry : kMap.entrySet()) {
                addCollectList(id, kEntry.getKey(), kEntry.getValue());
            }
        }
    }
}

上面的例子我是直接从别人那儿拿过来的(个人比较懒:https://www.cnblogs.com/lrxvx/p/12558902.html),实际方式也很简单,就是实现序列化和反序列化方法以及一个管理存储空间的方法。需要注意的是序列化和反序列化时候的顺序一定要保证,Presto提供了许多属性方式的选项如int、long、byte,对于String方式序列化,可以将String转为byte再进行序列化,思路就是先序列化一个长度进去,再将字节内容序列化,反序列化的时候先读length,再读相应的字节内容转为String就好了,而对象类型的属性,本质上还是可以直接序列化属性,反序列化时候重新创建对象,内容没变。Presto的序列化方式比较高效,原因是因为我可以只序列化我想要的属性就好了,缺点是扩展性不足。

序列化类:

public class CollectListStatsSerializer implements AccumulatorStateSerializer<CollectState> {
    @Override
    public Type getSerializedType() {
        return VARBINARY;
    }
    @Override
    public void serialize(CollectState state, BlockBuilder out) {
        if (state.get() == null) {
            out.appendNull();
        } else {
            VARBINARY.writeSlice(out, state.get().serialize());
        }
    }
    @Override
    public void deserialize(Block block, int index, CollectState state) {
        state.set(new CollectListStats(VARBINARY.getSlice(block, index)));
    }
}

序列化工厂类:

public class CollectListStatsFactory implements AccumulatorStateFactory<CollectState> {
    @Override
    public CollectState createSingleState() {
        return new SingleState();
    }
    @Override
    public Class<? extends CollectState> getSingleStateClass() {
        return SingleState.class;
    }
    @Override
    public CollectState createGroupedState() {
        return new GroupState();
    }
    @Override
    public Class<? extends CollectState> getGroupedStateClass() {
        return GroupState.class;
    }
    public static class GroupState implements GroupedAccumulatorState, CollectState {
        private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedDigestAndPercentileState.class).instanceSize();
        private final ObjectBigArray<CollectListStats> collectStatsList = new ObjectBigArray<>();
        private long size;
        private long groupId;
        @Override
        public void setGroupId(long groupId) {
            this.groupId = groupId;
        }
        @Override
        public void ensureCapacity(long size) {
            collectStatsList.ensureCapacity(size);
        }
        @Override
        public CollectListStats get() {
            return collectStatsList.get(groupId);
        }
        @Override
        public void set(CollectListStats value) {
            CollectListStats previous = get();
            if (previous != null) {
                size -= previous.estimatedInMemorySize();
            }
            collectStatsList.set(groupId, value);
            size += value.estimatedInMemorySize();
        }
        @Override
        public long getEstimatedSize() {
            return INSTANCE_SIZE +size + collectStatsList.sizeOf();
        }
    }
    public static class SingleState implements CollectState{
        private CollectListStats stats;
        @Override
        public CollectListStats get() {
            return stats;
        }
        @Override
        public void set(CollectListStats value) {
            stats = value;
        }
        @Override
        public long getEstimatedSize() {
            if (stats == null) {
                return 0;
            }
            return stats.estimatedInMemorySize();
        }
    }
}

验证自定义函数

当我们开发好自定义函数后如何验证呢,一种方式是使用Presto内置函数注册机制进行单元测试,Presto 函数由MetadataManager中的FunctionRegistry进行管理,开发的函数要生效必须要先注册到FunctionRegistry中。函数注册是在 Presto 服务启动过程中进行的,有以下两种方式进行函数注册。

FunctionListBuilder builder = new FunctionListBuilder()
                .window(RowNumberFunction.class)
                .aggregate(ApproximateCountDistinctAggregation.class)
                .scalar(RepeatFunction.class)
                .function(MAP_HASH_CODE)
                ......

注册好之后就可以编写相应的单元测试代码了。完整的Aggregation测试代码如下:

import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.FunctionListBuilder;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.operator.aggregation.InternalAggregationFunction;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import static com.facebook.presto.block.BlockAssertions.createSlicesBlock;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;

public class TestAggregation{


    private static final FunctionAndTypeManager FUNCTION_AND_TYPE_MANAGER = MetadataManager.createTestMetadataManager().getFunctionAndTypeManager();
    private static InternalAggregationFunction getAggregation(Type... arguments)
    {
        return FUNCTION_AND_TYPE_MANAGER.getAggregateFunctionImplementation(FUNCTION_AND_TYPE_MANAGER.lookupFunction("presto_collect", fromTypes(arguments)));
    }
    private static final InternalAggregationFunction COLLECTION_AGGREGATION = getAggregation(VARCHAR, VARCHAR); //和Aggregation中的类型对应,java类型的Slice对应Varchar

    @BeforeClass
    public void init(){
        FunctionListBuilder builder = new FunctionListBuilder().aggregate(CollectListAggregation.class);
        FUNCTION_AND_TYPE_MANAGER.registerBuiltInFunctions(builder.getFunctions());
    }

    @Test
    public void collectionAggregationTest(){
        String result="xxx"; //你期望的aggregation结果
        Slice str1= Slices.utf8Slice("x");
        Slice str2= Slices.utf8Slice("y");
        assertAggregation(
                COLLECTION_AGGREGATION,
                result,
                createSlicesBlock(str1, str2),
                createSlicesBlock(str1, str2));
    }
}

标量函数单元测试

而对于标量函数scalar的测试略有不同,示例如下:

public class TestBitwiseFunctions
        extends AbstractTestFunctions
{
    @Test
    public void testBitCount()
    {
        assertFunction("bit_count(0, 64)", BIGINT, 0L); //bit_count为标量函数名,传参,参数如果为String则用单引号,参数类型,期望结果
    }
}

当然进行单元测试后,我们期望到真实的库中去验证,内置函数满足不了使用需求时,就需要自行开发函数来拓展函数库。开发者自行编写的拓展函数一般通过插件的方式进行注册。PluginManager在安装插件时会调用插件的getFunctions()方法,将获取到的函数集合通过MetadataManager的addFunctions方法进行注册:


public class ExampleFunctionsPlugin
        implements Plugin
{
    @Override
    public Set<Class<?>> getFunctions()
    {
        return ImmutableSet.<Class<?>>builder()
                .add(ExampleNullFunction.class)
                .add(IsNullFunction.class)
                .add(IsEqualOrNullFunction.class)
                .add(ExampleStringFunction.class)
                .add(ExampleAverageFunction.class)
                .build();
    }
}

Presto 函数的注册机制,新增和修改函数后,必须要重启服务才能生效,所以目前还不支持真正的用户自定义函数。插件函数进行注册之后,在resource下创建META-INF/services目录,并创建文件名为com.facebook.presto.spi.Plugin的文件,并添加内容:

xxx.xxx.xxx.ExampleFunctionsPlugin

然后利用presto的插件打包,此时会在target目录下生成zip文件,把xxx.zip解压到${PRESTOHOME}/plugin,重启presto服务即可进行验证。

总的来说,Presto的UDF和UDAF开发总结为一张图:
在这里插入图片描述
注意各个版本的Presto源码有所不同,遇到类不正确的对版本进行调整,上面是用的Presto版本为0.264,更多的参考Presto的官方源码https://github.com/prestodb/presto,而对于Persto的分组聚合查询流程可以参见:Presto中的分组聚合查询流程

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

童话ing

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值