Spark DataFrame 使用UDF实现UDAF的一种方法

Spark DataFrame 使用UDF实现UDAF的一种方法

1、Background

当我们使用Spark Dataframe的时候常常需要进行group by操作,然后针对这一个group算出一个结果来。即所谓的聚合操作

然而 Spark提供的aggregation函数太少,常常不能满足我们的需要,怎么办呢?

Spark 贴心的提供了UDAF(User-defined aggregate function),听起来不错。
但是,这个函数实现起来太复杂,反正我是看的晕晕乎乎,难受的很。反倒是UDF的实现非常简单,无非是UDF针对所有行,UDAF针对一个group中的所有行。

So,两者在某种程度上是一样的。

2、如何用UDF实现UDAF的功能

举个例子来说明问题:
我们有一个dataframe是长这样的:

+-------+-------+-------+
|groupid|column1|column2|
+-------+-------+-------+
|   1   |  1    |   7   |
|   1   |  12   |   9   |
|   1   |  30   |   8   |
|   1   |  18   |   1   |
|   1   |  19   |   13  |
|   1   |  15   |   20  |
|   2   |  41   |   2   |
|   2   |  50   |   19  |
|   2   |  16   |   11  |
|   2   |  27   |   5   |
|   3   |  83   |   6   |
|   3   |  91   |   15  |
|   3   |  10   |   8   |

我们想对它group by id,然后对每一个group里的内容进行自定义操作。
比如寻找某一列第三大的数、通过某两列的数据计算出一个参数等等很多user-define的操作。

抽象的步骤看这里:

STEP.1. 对想要操作的列执行 collect_list(),生成新列,此时一个group就是一行。
        +-------+--------------------------+-----------------------+
        |groupid|        column1           |        column2        |
        +-------+--------------------------+-----------------------+
        |   1   |  [1,12,30,18,19,15]  | [7,9,8,1,13,20]   |
        |   2   |      [41,50,16,27]       |      [2,19,11,5]      | 
        |   3   |        [83,91,10]        |      [6,15,8]         |
STEP.2.写一个UDF,传入参数为上边生成的列,相当于传入了一个或多个数组。
 import org.apache.spark.sql.functions._
    def createNewCol = udf((column1: collection.mutable.WrappedArray[Int], column2: collection.mutable.WrappedArray[Int]) => {  // udf function
      var balabala  //各种要用到的自定义变量 
      var resultArray = Array.empty[(Int, Int, Int)]
      for(column1.size):  //遍历计算
          result[i] = 对俩数组column1,column2进行某种计算操作 //一个group中第i行的结果
      resultArray[i]=(column1[i],column2[i],result[i])
      resultArray   //返回值
    })    
STEP.3.UDF中可以对数组做任意操作,你对数组想怎么操作就怎么操作,最后返回一个数组就可以了,长度和你传入的数组相同(显然),数组每个元素的格式是tuple的 (column1.vaule,column2.value, result)因为 column1,column2的值我们后边展开的时候还要用。
STEP.4.执行UDF函数,传入的第一步中生成的列,获得结果列newcolumn,存储UDF的返回值。此时一个group还是一行。
+-------+-----------------------+--------------------+------------------------------------------------------------------------------+
|groupid|column1                |column2             |newcolumn                                                                     |
+-------+-----------------------+--------------------+------------------------------------------------------------------------------+
|1      |[1, 12, 30, 18, 19, 15]|[7, 9, 8, 1, 13, 20]|[[15, 20, 35], [19, 13, 32], [18, 1, 19], [30, 8, 38], [12, 9, 21], [1, 7, 8]]|
|3      |[83, 91, 10]           |[6, 15, 8]          |[[10, 8, 18], [91, 15, 106], [83, 6, 89]]                                     |
|2      |[41, 50, 16, 27]       |[2, 19, 11, 5]      |[[27, 5, 32], [16, 11, 27], [50, 19, 69], [41, 2, 43]]                        |
+-------+-----------------------+--------------------+------------------------------------------------------------------------------+
STEP.5. column1,column2可以丢掉了,因为用不到。
+-------+------------------------------------------------------------------------------+
|groupid|newcolumn                                                                     |
+-------+------------------------------------------------------------------------------+
|1      |[[15, 20, 35], [19, 13, 32], [18, 1, 19], [30, 8, 38], [12, 9, 21], [1, 7, 8]]|
|3      |[[10, 8, 18], [91, 15, 106], [83, 6, 89]]                                     |
|2      |[[27, 5, 32], [16, 11, 27], [50, 19, 69], [41, 2, 43]]                        |
+-------+------------------------------------------------------------------------------+
STEP.6.对结果列执行 explode(col("newcolumn"))操作,相当于把数组撑开来到整个group中。
+-------+-------------+
|groupid|new          |
+-------+-------------+
|1      |[15, 20, 35] |
|1      |[19, 13, 32] |
|1      |[18, 1, 19]  |
|1      |[30, 8, 38]  |
|1      |[12, 9, 21]  |
|1      |[1, 7, 8]    |
|3      |[10, 8, 18]  |
|3      |[91, 15, 106]|
|3      |[83, 6, 89]  |
|2      |[27, 5, 32]  |
|2      |[16, 11, 27] |
|2      |[50, 19, 69] |
|2      |[41, 2, 43]  |
+-------+-------------+
STEP.7.把tuple分开成三列

select(col("groupid"), col("new._1").as("rownum"), col("new._2").as("column2"), col("new._3").as("resultcolumn"))  //selecting as separate column

3、完整代码看这里

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, collect_list, explode, udf}

/**
 * @class TestUADF
 * @author yyz
 * @date 2021/01/25 15:48
 * Spark DataFrame 使用UDF实现UDAF的一种方法
 * https://segmentfault.com/a/1190000014088377
 * */
object TestUADF {

  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.OFF)
    val spark = SparkSession.builder().master("local").appName("AppName").getOrCreate()
    val userData = Array((1,1,7),
      (1,12,9),
      (1,30,8),
      (1,18,1),
      (1,19,13),
      (1,15,20),
      (2,41,2),
      (2,50,19),
      (2,16,11),
      (2,27,5),
      (3,83,6),
      (3,91,15),
      (3,10,8))
    val df = spark.createDataFrame(userData).toDF("groupid", "column1","column2")
    df.show()

    df.groupBy("groupid").agg(collect_list("column1").as("column1"),collect_list("column2").as("column2")) // 把要操作的列转换成数组,作为group的一个列属性。
      .withColumn("newcolumn", createNewCol(col("column1"), col("column2")))  //把存储数组的列传入udf,返回一个新列
      .drop("column1", "column2") //丢弃两个存储数组的列,因为用不到了
      .withColumn("new", explode(col("newcolumn"))) //把新计算出来的内容从一行explode到整个group
      .drop("newcolumn") //丢弃两个存储数组的列,因为用不到了
      .select(col("groupid"), col("new._1").as("rownum"), col("new._2").as("column2"), col("new._3").as("column3"))  //selecting as separate column
      .show(false)

  }

  def createNewCol = udf((column1: collection.mutable.WrappedArray[Int], column2: collection.mutable.WrappedArray[Int]) => {  // udf function
    var i =0 //各种要用到的自定义变量
    var result=0
    var resultArray = Array.empty[(Int,Int,Int)]
    for(i <- 0 to column1.size-1) { //遍历计算
      result = column1(i)+column2(i)     //对俩数组column1,column2进行某种计算操作 一个group中第i行的结果
//      println("i= ",i)
//      println("column1(i)= ",column1(i))
//      println("column2(i)= ",column2(i))
//      println("result= ",result)
//      println((column1(i), column2(i), result))
      resultArray=(column1(i), column2(i), result)+:resultArray
    }
    resultArray   //返回值
  })


}

输出

+-------+-------+-------+
|groupid|column1|column2|
+-------+-------+-------+
|      1|      1|      7|
|      1|     12|      9|
|      1|     30|      8|
|      1|     18|      1|
|      1|     19|     13|
|      1|     15|     20|
|      2|     41|      2|
|      2|     50|     19|
|      2|     16|     11|
|      2|     27|      5|
|      3|     83|      6|
|      3|     91|     15|
|      3|     10|      8|
+-------+-------+-------+

+-------+------+-------+-------+
|groupid|rownum|column2|column3|
+-------+------+-------+-------+
|1      |15    |20     |35     |
|1      |19    |13     |32     |
|1      |18    |1      |19     |
|1      |30    |8      |38     |
|1      |12    |9      |21     |
|1      |1     |7      |8      |
|3      |10    |8      |18     |
|3      |91    |15     |106    |
|3      |83    |6      |89     |
|2      |27    |5      |32     |
|2      |16    |11     |27     |
|2      |50    |19     |69     |
|2      |41    |2      |43     |
+-------+------+-------+-------+


Process finished with exit code 0

参考:https://docs.cloudera.com/documentation/enterprise/5-6-x/topics/cm_mc_hive_udf.html
          http://bdlabs.edureka.co/static/help/topics/cm_mc_hive_udf.html

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值