Spark DataFrame 使用UDF实现UDAF的一种方法
1、Background
当我们使用Spark Dataframe的时候常常需要进行group by
操作,然后针对这一个group算出一个结果来。即所谓的聚合操作。
然而 Spark提供的aggregation
函数太少,常常不能满足我们的需要,怎么办呢?
Spark 贴心的提供了UDAF(User-defined aggregate function),听起来不错。
但是,这个函数实现起来太复杂,反正我是看的晕晕乎乎,难受的很。反倒是UDF的实现非常简单,无非是UDF针对所有行,UDAF针对一个group中的所有行。
So,两者在某种程度上是一样的。
2、如何用UDF实现UDAF的功能
举个例子来说明问题:
我们有一个dataframe是长这样的:
+-------+-------+-------+
|groupid|column1|column2|
+-------+-------+-------+
| 1 | 1 | 7 |
| 1 | 12 | 9 |
| 1 | 30 | 8 |
| 1 | 18 | 1 |
| 1 | 19 | 13 |
| 1 | 15 | 20 |
| 2 | 41 | 2 |
| 2 | 50 | 19 |
| 2 | 16 | 11 |
| 2 | 27 | 5 |
| 3 | 83 | 6 |
| 3 | 91 | 15 |
| 3 | 10 | 8 |
我们想对它group by id
,然后对每一个group里的内容进行自定义操作。
比如寻找某一列第三大的数、通过某两列的数据计算出一个参数等等很多user-define
的操作。
抽象的步骤看这里:
STEP.1. 对想要操作的列执行
collect_list()
,生成新列,此时一个group就是一行。
+-------+--------------------------+-----------------------+
|groupid| column1 | column2 |
+-------+--------------------------+-----------------------+
| 1 | [1,12,30,18,19,15] | [7,9,8,1,13,20] |
| 2 | [41,50,16,27] | [2,19,11,5] |
| 3 | [83,91,10] | [6,15,8] |
STEP.2.写一个UDF,传入参数为上边生成的列,相当于传入了一个或多个数组。
import org.apache.spark.sql.functions._
def createNewCol = udf((column1: collection.mutable.WrappedArray[Int], column2: collection.mutable.WrappedArray[Int]) => { // udf function
var balabala //各种要用到的自定义变量
var resultArray = Array.empty[(Int, Int, Int)]
for(column1.size): //遍历计算
result[i] = 对俩数组column1,column2进行某种计算操作 //一个group中第i行的结果
resultArray[i]=(column1[i],column2[i],result[i])
resultArray //返回值
})
STEP.3.UDF中可以对数组做任意操作,你对数组想怎么操作就怎么操作,最后返回一个数组就可以了,长度和你传入的数组相同(显然),数组每个元素的格式是tuple的(column1.vaule,column2.value, result)
因为column1,column2
的值我们后边展开的时候还要用。
STEP.4.执行UDF函数,传入的第一步中生成的列,获得结果列newcolumn,存储UDF的返回值。此时一个group还是一行。
+-------+-----------------------+--------------------+------------------------------------------------------------------------------+
|groupid|column1 |column2 |newcolumn |
+-------+-----------------------+--------------------+------------------------------------------------------------------------------+
|1 |[1, 12, 30, 18, 19, 15]|[7, 9, 8, 1, 13, 20]|[[15, 20, 35], [19, 13, 32], [18, 1, 19], [30, 8, 38], [12, 9, 21], [1, 7, 8]]|
|3 |[83, 91, 10] |[6, 15, 8] |[[10, 8, 18], [91, 15, 106], [83, 6, 89]] |
|2 |[41, 50, 16, 27] |[2, 19, 11, 5] |[[27, 5, 32], [16, 11, 27], [50, 19, 69], [41, 2, 43]] |
+-------+-----------------------+--------------------+------------------------------------------------------------------------------+
STEP.5.
column1,column2
可以丢掉了,因为用不到。
+-------+------------------------------------------------------------------------------+
|groupid|newcolumn |
+-------+------------------------------------------------------------------------------+
|1 |[[15, 20, 35], [19, 13, 32], [18, 1, 19], [30, 8, 38], [12, 9, 21], [1, 7, 8]]|
|3 |[[10, 8, 18], [91, 15, 106], [83, 6, 89]] |
|2 |[[27, 5, 32], [16, 11, 27], [50, 19, 69], [41, 2, 43]] |
+-------+------------------------------------------------------------------------------+
STEP.6.对结果列执行
explode(col("newcolumn"))
操作,相当于把数组撑开来到整个group中。
+-------+-------------+
|groupid|new |
+-------+-------------+
|1 |[15, 20, 35] |
|1 |[19, 13, 32] |
|1 |[18, 1, 19] |
|1 |[30, 8, 38] |
|1 |[12, 9, 21] |
|1 |[1, 7, 8] |
|3 |[10, 8, 18] |
|3 |[91, 15, 106]|
|3 |[83, 6, 89] |
|2 |[27, 5, 32] |
|2 |[16, 11, 27] |
|2 |[50, 19, 69] |
|2 |[41, 2, 43] |
+-------+-------------+
STEP.7.把tuple分开成三列
select(col("groupid"), col("new._1").as("rownum"), col("new._2").as("column2"), col("new._3").as("resultcolumn")) //selecting as separate column
3、完整代码看这里
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, collect_list, explode, udf}
/**
* @class TestUADF
* @author yyz
* @date 2021/01/25 15:48
* Spark DataFrame 使用UDF实现UDAF的一种方法
* https://segmentfault.com/a/1190000014088377
* */
object TestUADF {
def main(args: Array[String]): Unit = {
Logger.getLogger("org").setLevel(Level.OFF)
val spark = SparkSession.builder().master("local").appName("AppName").getOrCreate()
val userData = Array((1,1,7),
(1,12,9),
(1,30,8),
(1,18,1),
(1,19,13),
(1,15,20),
(2,41,2),
(2,50,19),
(2,16,11),
(2,27,5),
(3,83,6),
(3,91,15),
(3,10,8))
val df = spark.createDataFrame(userData).toDF("groupid", "column1","column2")
df.show()
df.groupBy("groupid").agg(collect_list("column1").as("column1"),collect_list("column2").as("column2")) // 把要操作的列转换成数组,作为group的一个列属性。
.withColumn("newcolumn", createNewCol(col("column1"), col("column2"))) //把存储数组的列传入udf,返回一个新列
.drop("column1", "column2") //丢弃两个存储数组的列,因为用不到了
.withColumn("new", explode(col("newcolumn"))) //把新计算出来的内容从一行explode到整个group
.drop("newcolumn") //丢弃两个存储数组的列,因为用不到了
.select(col("groupid"), col("new._1").as("rownum"), col("new._2").as("column2"), col("new._3").as("column3")) //selecting as separate column
.show(false)
}
def createNewCol = udf((column1: collection.mutable.WrappedArray[Int], column2: collection.mutable.WrappedArray[Int]) => { // udf function
var i =0 //各种要用到的自定义变量
var result=0
var resultArray = Array.empty[(Int,Int,Int)]
for(i <- 0 to column1.size-1) { //遍历计算
result = column1(i)+column2(i) //对俩数组column1,column2进行某种计算操作 一个group中第i行的结果
// println("i= ",i)
// println("column1(i)= ",column1(i))
// println("column2(i)= ",column2(i))
// println("result= ",result)
// println((column1(i), column2(i), result))
resultArray=(column1(i), column2(i), result)+:resultArray
}
resultArray //返回值
})
}
输出
+-------+-------+-------+
|groupid|column1|column2|
+-------+-------+-------+
| 1| 1| 7|
| 1| 12| 9|
| 1| 30| 8|
| 1| 18| 1|
| 1| 19| 13|
| 1| 15| 20|
| 2| 41| 2|
| 2| 50| 19|
| 2| 16| 11|
| 2| 27| 5|
| 3| 83| 6|
| 3| 91| 15|
| 3| 10| 8|
+-------+-------+-------+
+-------+------+-------+-------+
|groupid|rownum|column2|column3|
+-------+------+-------+-------+
|1 |15 |20 |35 |
|1 |19 |13 |32 |
|1 |18 |1 |19 |
|1 |30 |8 |38 |
|1 |12 |9 |21 |
|1 |1 |7 |8 |
|3 |10 |8 |18 |
|3 |91 |15 |106 |
|3 |83 |6 |89 |
|2 |27 |5 |32 |
|2 |16 |11 |27 |
|2 |50 |19 |69 |
|2 |41 |2 |43 |
+-------+------+-------+-------+
Process finished with exit code 0
参考:https://docs.cloudera.com/documentation/enterprise/5-6-x/topics/cm_mc_hive_udf.html
http://bdlabs.edureka.co/static/help/topics/cm_mc_hive_udf.html