Presto 在交互式查询任务中担当着重要的职责。随着越来越多的人开始使用 SQL 在 Presto 上分析数据,我们发现需要将一些业务逻辑开发成类似 Hive 中的 UDF,提高 SQL 使用人员的效率,同时也保证 Hive 和 Presto 环境中的 UDF 统一。
1、Presto函数介绍
在此之前先简单介绍下UDF和UDAF,UDF叫做用户自定义函数,而UDAF叫做用户自定义聚合函数,区别就在于UDF不会保存状态,一行输入一行输出,而UDAF会涉及到状态的保存,通过聚合多个节点的数据来转换为最终的输出结果。
在 Presto 中,函数大体分为三种:scalar,aggregation 和 window 类型。分别如下:
1)scalar标量函数,简单来说就是 Java 中的一个静态方法,本身没有任何状态(不保存数据,一行输入一行输出)。
2)aggregation累积状态的函数,或聚集函数,如count,avg。如果只是单节点,单机状态可以直接用一个变量存储即可,但是presto是分布式计算引擎,状态数据会在多个节点之间传输,因此状态数据需要被序列化成 Presto 的内部格式才可以被传输。简单来说Aggregation对应于多行输入一行输出。
3)window 窗口函数,窗口函数在查询结果的行上进行计算,执行顺序在HAVING子句之后,ORDER BY子句之前。在 Presto SQL 中,窗口函数的语法形式如下:
windowFunction(arg1,....argn) OVER([PARTITION BY<...>] [ORDER BY<...>] [RANGE|ROWS
BETWEEN AND])
窗口函数语法由关键字OVER触发,且包含三个子句:
PARTITION BY: 指定输入行分区的规则,类似于聚合函数的GROUP BY子句,不同分区里的计算互不干扰(窗口函数的计算是并发进行的,并发数和partition数量一致),缺省时将所有数据行视为一个分区
ORDER BY: 决定了窗口函数处理输入行的顺序
RANGE|ROWS BETWEEN AND: 指定窗口边界,不常用,缺省时的窗口为当前行所在的分区第一行到当前行。
2、自定义函数
官方文档地址:https://prestodb.io/docs/current/develop/functions.html
2.1自定义Scalar函数的实现
2.1.1定义一个java类
1)用 @ScalarFunction 的 Annotation 标记实现业务逻辑的静态方法。
2)用 @Description 描述函数的作用,这里的内容会在 SHOW FUNCTIONS 中显示。
3)用@SqlType 标记函数的返回值类型,如返回字符串,因此是 StandardTypes.VARCHAR。
4)Java 方法的返回值必须使用 Presto 内部的序列化方式,因此字符串类型必须返回 Slice, 使用 Slices.utf8Slice 方法可以方便的将 String 类型转换成 Slice 类型
public class ExampleStringFunction
{
@ScalarFunction("lowercaser")
@Description("converts the string to alternating case")
@SqlType(StandardTypes.VARCHAR)
public static Slice lowercaser(@SqlType(StandardTypes.VARCHAR) Slice slice)
{
String argument = slice.toStringUtf8();
return Slices.utf8Slice(argument.toLowerCase());
}
}
2.2 自定义Aggregation函数
2.2.1实现原理步骤
Presto 把 Aggregation 函数分解成三个步骤执行:
1、input(state, data): 针对每条数据,执行 input 函数。这个过程是并行执行的,因此在每个有数据的节点都会执行,最终得到多个累积的状态数据。
2、combine(state1, state2):将所有节点的状态数据聚合起来,多次执行,直至所有状态数据被聚合成一个最终状态,也就是 Aggregation 函数的输出结果。
3、output(final_state, out):最终输出结果到一个 BlockBuilder。
2.2.2 具体代码实现过程
1、定义一个 Java 类,使用 @AggregationFunction 标记为 Aggregation 函数
2、使用 @InputFunction、 @CombineFunction、@OutputFunction 分别标记计算函数、合并结果函数和最终输出函数在 Plugin 处注册 Aggregation 函数
3、一个继承AccumulatorState的State接口,get和set方法
4、并使用 @AccumulatorStateMetadata 提供序列化(stateSerializerClass指定)和 Factory 类信息(stateFactoryClass指定)。自己写一个序列化类和一个工厂类。(复杂数据类型需要:自定义类保存状态、Map、List等)
简单类型Aggregation
对于简单数据类型的聚合函数编写比较简单,实现一个包含input、combine、output的aggregation和一个状态设定接口State提供get、set方法即可,不用去关心序列化和状态保存问题。
Aggregation:
@AggregationFunction("avg")
public final class IntervalYearToMonthAverageAggregation
{
private IntervalYearToMonthAverageAggregation() {}
@InputFunction
public static void input(LongAndDoubleState state, @SqlType(StandardTypes.INTERVAL_YEAR_TO_MONTH) long value)
{
state.setLong(state.getLong() + 1);
state.setDouble(state.getDouble() + value);
}
@CombineFunction
public static void combine(LongAndDoubleState state, LongAndDoubleState otherState)
{
state.setLong(state.getLong() + otherState.getLong());
state.setDouble(state.getDouble() + otherState.getDouble());
}
@OutputFunction(StandardTypes.INTERVAL_YEAR_TO_MONTH)
public static void output(LongAndDoubleState state, BlockBuilder out)
{
long count = state.getLong();
if (count == 0) {
out.appendNull();
}
else {
double value = state.getDouble();
INTERVAL_YEAR_MONTH.writeLong(out, round(value / count));
}
}
}
LongAndDoubleState :写一个接口实现继承自AccumulatorState类,提供get、set方法即可。
public interface LongAndDoubleState
extends AccumulatorState
{
long getLong();
void setLong(long value);
double getDouble();
void setDouble(double value);
}
复杂类型Aggregation
对于复杂数据类型则需要提供序列化机制,你要序列化那些东西都是由你来制指定的。在AccumulatorState的接口上用注解指定@AccumulatorStateMetadata 提供序列化(stateSerializerClass指定)和 Factory 类信息(stateFactoryClass指定),自定义一个序列化器和序列化工厂类,实现类的序列化和反序列化。
Aggregation类: 这个类实现比较简单,和简单数据类型的实现一样,input、combine、output。
@AggregationFunction("presto_collect")
public class CollectListAggregation {
@InputFunction
public static void input(@AggregationState CollectState state, @SqlType(StandardTypes.VARCHAR) Slice id,@SqlType(StandardTypes.VARCHAR) Slice key) {
try {
CollectListStats stats = state.get();
if (stats == null) {
stats = new CollectListStats();
state.set(stats);
}
int inputId = Integer.parseInt(id.toStringUtf8());
String inputKey = key.toStringUtf8();
stats.addCollectList(inputId,inputKey, 1);
} catch (Exception e) {
throw new RuntimeException(e+" --------- input err");
}
}
@CombineFunction
public static void combine(@AggregationState CollectState state, CollectState otherState) {
try {
CollectListStats collectListStats = state.get();
CollectListStats oCollectListStats = otherState.get();
if(collectListStats == null) {
state.set(oCollectListStats);
} else {
collectListStats.mergeWith(oCollectListStats);
}
}catch (Exception e) {
throw new RuntimeException(e+" --------- combine err");
}
}
@OutputFunction(StandardTypes.*VARCHAR*)
public static void output(@AggregationState CollectState state, BlockBuilder out) {
try {
CollectListStats stats = state.get();
if (stats == null) {
out.appendNull();
return;
}
// 统计
Slice result = stats.getCollectResult();
if (result == null) {
out.appendNull();
} else {
VarcharType.VARCHAR.writeSlice(out, result);
}
} catch (Exception e) {
throw new RuntimeException(e+" -------- output err");
}
}
}
状态保存接口:
@AccumulatorStateMetadata(stateSerializerClass = CollectListStatsSerializer.class, stateFactoryClass = CollectListStatsFactory.class)
public interface CollectState extends AccumulatorState {
CollectListStats get();
void set(CollectListStats value);
}
存放数据的类:此类需要实现数据的序列化和反序列化,这是最关键和比较麻烦的地方,贴一个例子,关键在于需要自己控制存储空间以及数据的顺序,和读取的时候按照一定顺序读取。对于字符要先存储长度,然后是字节,读取则先读取字符长度,然后读取这么长的数据,最后转化为字符。
public class CollectListStats {
private static final int INSTANCE_SIZE = ClassLayout.parseClass(CollectListStats.class).instanceSize();
//<id,<key,value>>
private Map<Integer,Map<String,Integer>> collectContainer = new HashMap<>();
private long contentEstimatedSize = 0;
private int keyByteLen = 0;
private int keyListLen = 0;
CollectListStats() {
}
CollectListStats(Slice serialized) {
deserialize(serialized);
}
void addCollectList(Integer id, String key, int value) {
if (collectContainer.containsKey(id)) {
Map<String, Integer> tmpMap = collectContainer.get(id);
if (tmpMap.containsKey(key)) {
tmpMap.put(key, tmpMap.get(key)+value);
}else{
tmpMap.put(key,value);
contentEstimatedSize += ( key.getBytes().length + SizeOf.SIZE_OF_INT*);
keyByteLen += key.getBytes().length;
keyListLen++;
}
} else {
Map<String,Integer> tmpMap = new HashMap<String,Integer>();
tmpMap.put(key, value);
keyByteLen += key.getBytes().length;
keyListLen++;
collectContainer.put(id, tmpMap);
contentEstimatedSize += SizeOf.SIZE_OF_INT;
}
}
//[{id:1,{"aaa":3,"fadf":6},{}]
Slice getCollectResult() {
Slice jsonSlice = null;
try {
StringBuilder jsonStr = new StringBuilder();
jsonStr.append("[");
int collectLength = collectContainer.entrySet().size();
for (Map.Entry<Integer, Map<String, Integer>> mapEntry : collectContainer.entrySet()) {
Integer id = mapEntry.getKey();
Map<String, Integer> vMap = mapEntry.getValue();
jsonStr.append("{id:").append(id).append(",{");
int vLength = vMap.entrySet().size();
for (Map.Entry<String, Integer> vEntry : vMap.entrySet()) {
String key = vEntry.getKey();
Integer value = vEntry.getValue();
jsonStr.append(key).append(":").append(value);
vLength--;
if (vLength != 0) {
jsonStr.append(",");
}
}
jsonStr.append("}");
collectLength--;
if (collectLength != 0) {
jsonStr.append(",");
}
}
jsonStr.append("]");
jsonSlice = Slices.utf8Slice*(jsonStr.toString());
} catch (Exception e) {
throw new RuntimeException(e+" ---------- get CollectResult err");
}
return jsonSlice;
}
public void deserialize(Slice serialized) {
try {
SliceInput input = serialized.getInput();
//外层map的长度
int collectStatsEntrySize = input.readInt();
for (int collectCnt = 0; collectCnt < collectStatsEntrySize; collectCnt++) {
int id = input.readInt();
int keyEntrySize = input.readInt();
for (int idCnt = 0; idCnt < keyEntrySize; idCnt++) {
int keyBytesLen = input.readInt();
byte[] keyBytes = new byte[keyBytesLen];
for (int byteIdx = 0; byteIdx < keyBytesLen; byteIdx++) {
keyBytes[byteIdx] = input.readByte();
}
String key = new String(keyBytes);
int value = input.readInt();
addCollectList(id, key, value);
}
}
} catch (Exception e) {
throw new RuntimeException(e+" ----- deserialize err");
}
}
public Slice serialize() {
SliceOutput builder = null;
int requiredBytes = //对应 SliceOutput builder append的内容所占用的空间
SizeOf.SIZE_OF_INT*3 //id entry数目,id数值,key Entry数目
\+ keyListLen * SizeOf.SIZE_OF_INT* //key bytes长度
\+ keyByteLen* //key byte总长度
\+ keyListLen * SizeOf.SIZE_OF_INT; //value
try {
// 序列化
builder = Slices.*allocate*(requiredBytes).getOutput();
for (Map.Entry<Integer,Map<String, Integer>> entry : collectContainer.entrySet()) {
//id个数
builder.appendInt(collectContainer.entrySet().size());
//id 数值
builder.appendInt(entry.getKey());
Map<String, Integer> kMap = entry.getValue();
builder.appendInt(kMap.entrySet().size());
for (Map.Entry<String, Integer> vEntry : kMap.entrySet()) {
byte[] keyBytes = vEntry.getKey().getBytes();
builder.appendInt(keyBytes.length);
builder.appendBytes(keyBytes);
builder.appendInt(vEntry.getValue());
}
}
return builder.getUnderlyingSlice();
} catch (Exception e) {
throw new RuntimeException(e+" ---- serialize err requiredBytes = " + requiredBytes + " keyByteLen= " + keyByteLen + " keyListLen = " + keyListLen);
}
}
long estimatedInMemorySize() {
return INSTANCE_SIZE + contentEstimatedSize;
}
void mergeWith(CollectListStats other) {
if (other == null) {
return;
}
for (Map.Entry<Integer,Map<String, Integer>> cEntry : other.collectContainer.entrySet()) {
Integer id = cEntry.getKey();
Map<String, Integer> kMap = cEntry.getValue();
for (Map.Entry<String, Integer> kEntry : kMap.entrySet()) {
addCollectList(id, kEntry.getKey(), kEntry.getValue());
}
}
}
}
上面的例子我是直接从别人那儿拿过来的(个人比较懒:https://www.cnblogs.com/lrxvx/p/12558902.html),实际方式也很简单,就是实现序列化和反序列化方法以及一个管理存储空间的方法。需要注意的是序列化和反序列化时候的顺序一定要保证,Presto提供了许多属性方式的选项如int、long、byte,对于String方式序列化,可以将String转为byte再进行序列化,思路就是先序列化一个长度进去,再将字节内容序列化,反序列化的时候先读length,再读相应的字节内容转为String就好了,而对象类型的属性,本质上还是可以直接序列化属性,反序列化时候重新创建对象,内容没变。Presto的序列化方式比较高效,原因是因为我可以只序列化我想要的属性就好了,缺点是扩展性不足。
序列化类:
public class CollectListStatsSerializer implements AccumulatorStateSerializer<CollectState> {
@Override
public Type getSerializedType() {
return VARBINARY;
}
@Override
public void serialize(CollectState state, BlockBuilder out) {
if (state.get() == null) {
out.appendNull();
} else {
VARBINARY.writeSlice(out, state.get().serialize());
}
}
@Override
public void deserialize(Block block, int index, CollectState state) {
state.set(new CollectListStats(VARBINARY.getSlice(block, index)));
}
}
序列化工厂类:
public class CollectListStatsFactory implements AccumulatorStateFactory<CollectState> {
@Override
public CollectState createSingleState() {
return new SingleState();
}
@Override
public Class<? extends CollectState> getSingleStateClass() {
return SingleState.class;
}
@Override
public CollectState createGroupedState() {
return new GroupState();
}
@Override
public Class<? extends CollectState> getGroupedStateClass() {
return GroupState.class;
}
public static class GroupState implements GroupedAccumulatorState, CollectState {
private static final int INSTANCE_SIZE = ClassLayout.parseClass(GroupedDigestAndPercentileState.class).instanceSize();
private final ObjectBigArray<CollectListStats> collectStatsList = new ObjectBigArray<>();
private long size;
private long groupId;
@Override
public void setGroupId(long groupId) {
this.groupId = groupId;
}
@Override
public void ensureCapacity(long size) {
collectStatsList.ensureCapacity(size);
}
@Override
public CollectListStats get() {
return collectStatsList.get(groupId);
}
@Override
public void set(CollectListStats value) {
CollectListStats previous = get();
if (previous != null) {
size -= previous.estimatedInMemorySize();
}
collectStatsList.set(groupId, value);
size += value.estimatedInMemorySize();
}
@Override
public long getEstimatedSize() {
return INSTANCE_SIZE +size + collectStatsList.sizeOf();
}
}
public static class SingleState implements CollectState{
private CollectListStats stats;
@Override
public CollectListStats get() {
return stats;
}
@Override
public void set(CollectListStats value) {
stats = value;
}
@Override
public long getEstimatedSize() {
if (stats == null) {
return 0;
}
return stats.estimatedInMemorySize();
}
}
}
验证自定义函数
当我们开发好自定义函数后如何验证呢,一种方式是使用Presto内置函数注册机制进行单元测试,Presto 函数由MetadataManager中的FunctionRegistry进行管理,开发的函数要生效必须要先注册到FunctionRegistry中。函数注册是在 Presto 服务启动过程中进行的,有以下两种方式进行函数注册。
FunctionListBuilder builder = new FunctionListBuilder()
.window(RowNumberFunction.class)
.aggregate(ApproximateCountDistinctAggregation.class)
.scalar(RepeatFunction.class)
.function(MAP_HASH_CODE)
......
注册好之后就可以编写相应的单元测试代码了。完整的Aggregation测试代码如下:
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.FunctionListBuilder;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.operator.aggregation.InternalAggregationFunction;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
import static com.facebook.presto.block.BlockAssertions.createSlicesBlock;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.operator.aggregation.AggregationTestUtils.assertAggregation;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
public class TestAggregation{
private static final FunctionAndTypeManager FUNCTION_AND_TYPE_MANAGER = MetadataManager.createTestMetadataManager().getFunctionAndTypeManager();
private static InternalAggregationFunction getAggregation(Type... arguments)
{
return FUNCTION_AND_TYPE_MANAGER.getAggregateFunctionImplementation(FUNCTION_AND_TYPE_MANAGER.lookupFunction("presto_collect", fromTypes(arguments)));
}
private static final InternalAggregationFunction COLLECTION_AGGREGATION = getAggregation(VARCHAR, VARCHAR); //和Aggregation中的类型对应,java类型的Slice对应Varchar
@BeforeClass
public void init(){
FunctionListBuilder builder = new FunctionListBuilder().aggregate(CollectListAggregation.class);
FUNCTION_AND_TYPE_MANAGER.registerBuiltInFunctions(builder.getFunctions());
}
@Test
public void collectionAggregationTest(){
String result="xxx"; //你期望的aggregation结果
Slice str1= Slices.utf8Slice("x");
Slice str2= Slices.utf8Slice("y");
assertAggregation(
COLLECTION_AGGREGATION,
result,
createSlicesBlock(str1, str2),
createSlicesBlock(str1, str2));
}
}
标量函数单元测试
而对于标量函数scalar的测试略有不同,示例如下:
public class TestBitwiseFunctions
extends AbstractTestFunctions
{
@Test
public void testBitCount()
{
assertFunction("bit_count(0, 64)", BIGINT, 0L); //bit_count为标量函数名,传参,参数如果为String则用单引号,参数类型,期望结果
}
}
当然进行单元测试后,我们期望到真实的库中去验证,内置函数满足不了使用需求时,就需要自行开发函数来拓展函数库。开发者自行编写的拓展函数一般通过插件的方式进行注册。PluginManager在安装插件时会调用插件的getFunctions()方法,将获取到的函数集合通过MetadataManager的addFunctions方法进行注册:
public class ExampleFunctionsPlugin
implements Plugin
{
@Override
public Set<Class<?>> getFunctions()
{
return ImmutableSet.<Class<?>>builder()
.add(ExampleNullFunction.class)
.add(IsNullFunction.class)
.add(IsEqualOrNullFunction.class)
.add(ExampleStringFunction.class)
.add(ExampleAverageFunction.class)
.build();
}
}
Presto 函数的注册机制,新增和修改函数后,必须要重启服务才能生效,所以目前还不支持真正的用户自定义函数。插件函数进行注册之后,在resource下创建META-INF/services
目录,并创建文件名为com.facebook.presto.spi.Plugin
的文件,并添加内容:
xxx.xxx.xxx.ExampleFunctionsPlugin
然后利用presto的插件打包,此时会在target目录下生成zip文件,把xxx.zip解压到${PRESTOHOME}/plugin
,重启presto服务即可进行验证。
总的来说,Presto的UDF和UDAF开发总结为一张图:
注意:各个版本的Presto源码有所不同,遇到类不正确的对版本进行调整,上面是用的Presto版本为0.264,更多的参考Presto的官方源码
https://github.com/prestodb/presto,而对于Persto的分组聚合查询流程可以参见:Presto中的分组聚合查询流程