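I am trying to create a user-defined aggregate function (UDAF) in Java with Apache Spark SQL that returns multiple arrays on completion. I have searched online and cannot find any examples or suggestions on how to do this.
I am able to return a single array, but I cannot figure out how to get the data into the correct format in the evaluate() method to return multiple arrays.
The UDAF itself does work, since I can print the arrays out inside evaluate(). I just cannot figure out how to return those arrays to the calling code (shown below for reference).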
UserDefinedAggregateFunction customUDAF = new CustomUDAF();
DataFrame resultingDataFrame = dataFrame.groupBy().agg(customUDAF.apply(dataFrame.col("long_col"), dataFrame.col("double_col"))).as("processed_data");
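I have included the entire custom UDAF class below, but the key methods are dataType() and evaluate(), which are shown first.
Any help or advice would be greatly appreciated. Thank you.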
public class CustomUDAF extends UserDefinedAggregateFunction {

    @Override
    public DataType dataType() {
        // TODO: Is this the correct way to return 2 arrays?
        return new StructType().add("longArray", DataTypes.createArrayType(DataTypes.LongType, false))
            .add("dataArray", DataTypes.createArrayType(DataTypes.DoubleType, false));
    }

    @Override
    public Object evaluate(Row buffer) {
        // Data conversion
        List longList = new ArrayList(buffer.getList(0));
        List dataList = new ArrayList(buffer.getList(1));
        // Processing of data (omitted)
        // TODO: How to get data into format needed to return 2 arrays?
        return dataList;
    }

    @Override
    public StructType inputSchema() {
        return new StructType().add("long", DataTypes.LongType).add("data", DataTypes.DoubleType);
    }

    @Override
    public StructType bufferSchema() {
        return new StructType().add("longArray", DataTypes.createArrayType(DataTypes.LongType, false))
            .add("dataArray", DataTypes.createArrayType(DataTypes.DoubleType, false));
    }

    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, new ArrayList());
        buffer.update(1, new ArrayList());
    }

    @Override
    public void update(MutableAggregationBuffer buffer, Row row) {
        ArrayList longList = new ArrayList(buffer.getList(0));
        longList.add(row.getLong(0));
        ArrayList dataList = new ArrayList(buffer.getList(1));
        dataList.add(row.getDouble(1));
        buffer.update(0, longList);
        buffer.update(1, dataList);
    }

    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        ArrayList longList = new ArrayList(buffer1.getList(0));
        longList.addAll(buffer2.getList(0));
        ArrayList dataList = new ArrayList(buffer1.getList(1));
        dataList.addAll(buffer2.getList(1));
        buffer1.update(0, longList);
        buffer1.update(1, dataList);
    }

    @Override
    public boolean deterministic() {
        return true;
    }
}
Update: Based on zero323's answer, I was able to return the two arrays using the following:
return new Tuple2<>(longArray, dataArray);
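Getting the data back out of this was a little harder: it involved deconstructing the DataFrame into Java Lists and then building it back up into a DataFrame.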
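To make the shape of that result concrete without a Spark dependency, here is a rough plain-Java model of the aggregation: two lists that grow in step during update/merge and come back together as one two-field value at the end. The names `TwoListBuffer` and `Result` are hypothetical stand-ins I made up for illustration; they are not part of the Spark API, and this is only a sketch of the logic, not the UDAF itself.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical plain-Java stand-in for the UDAF's two-column buffer (not a Spark class).
class TwoListBuffer {

    // Hypothetical two-field result, playing the role of the struct declared in dataType().
    static class Result {
        final List<Long> longArray;
        final List<Double> dataArray;

        Result(List<Long> longArray, List<Double> dataArray) {
            this.longArray = longArray;
            this.dataArray = dataArray;
        }
    }

    private final List<Long> longs = new ArrayList<>();
    private final List<Double> doubles = new ArrayList<>();

    // Mirrors update(): append one input row to both lists.
    void update(long longValue, double dataValue) {
        longs.add(longValue);
        doubles.add(dataValue);
    }

    // Mirrors merge(): fold another partial buffer into this one.
    void merge(TwoListBuffer other) {
        longs.addAll(other.longs);
        doubles.addAll(other.doubles);
    }

    // Mirrors evaluate(): hand back BOTH lists as a single two-field value.
    Result evaluate() {
        return new Result(new ArrayList<>(longs), new ArrayList<>(doubles));
    }

    public static void main(String[] args) {
        TwoListBuffer a = new TwoListBuffer();
        a.update(1L, 1.5);
        a.update(2L, 2.5);

        TwoListBuffer b = new TwoListBuffer();
        b.update(3L, 3.5);

        a.merge(b);
        Result r = a.evaluate();
        System.out.println(r.longArray + " " + r.dataArray); // prints "[1, 2, 3] [1.5, 2.5, 3.5]"
    }
}
```

In the real UDAF the role of `Result` is played by the struct declared in dataType(), and the Tuple2 above is what fills it. I would expect `RowFactory.create(longArray, dataArray)` to work as the return value as well, but I have not verified that here.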