First, initialize a Dataset:
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}
import java.text.SimpleDateFormat
case class Order(detail_id: String, order_id: String, product_id: String, num: Int, amt: Double, order_time: String)
object udf_udaf_udtf {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("udf-udaf-udtf")
      .master("local[*]").enableHiveSupport().getOrCreate()
    import spark.implicits._
    val arr0 = Array(
      ("d001","o001","jacket001",1,200.2,"2022-07-20 10:25:19"),
      ("d002","o001","shoe001",1,305.3,"2022-07-20 10:25:21"),
      ("d003","o002","tshirt001",2,158.8,"2022-07-21 10:26:01"),
      ("d004","o003","skirt001",1,88.6,"2022-07-21 18:26:32"),
      ("d005","o003","skirt002",1,62.1,"2022-07-21 18:27:02"),
      ("d006","o004","shoe001",1,305.3,"2022-07-22 12:27:33"),
      ("d007","o005","jacket001",2,400.4,"2022-07-25 15:19:58"),
      ("d008","o005","shoe001",2,610.6,"2022-07-25 15:19:59"),
      ("d009","o006","skirt002",2,124.2,"2022-07-27 13:09:55"),
      ("d010","o007","skirt001",1,88.6,"2022-07-27 14:06:55"))
    // Create the Dataset (the tuple fields already have the right types, so no casts are needed)
    val orderDS = spark.sparkContext.parallelize(arr0)
      .map(line => Order(line._1, line._2, line._3, line._4, line._5, line._6))
      .toDS()
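    // Aside: the same Dataset can also be built without the explicit RDD step.
    // A minimal sketch (orderDS2 is an illustrative name; the result is identical):
    val orderDS2 = arr0.toSeq.map((Order.apply _).tupled).toDS()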
UDF registration and usage:
    // Define a UDF helper: convert a standard datetime string to a millisecond timestamp
    def datetimeTrans(datetime: String): String = {
      val fm = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
      val dt = fm.parse(datetime)
      val tim: Long = dt.getTime
      tim.toString
    }
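    // Note: SimpleDateFormat is not thread-safe, which is why it is rebuilt on
    // every call above. A sketch of the same conversion with java.time
    // (assumes Java 8+; datetimeTrans2 is an illustrative name):
    def datetimeTrans2(datetime: String): String = {
      val fmt = java.time.format.DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")
      java.time.LocalDateTime.parse(datetime, fmt)
        .atZone(java.time.ZoneId.systemDefault())
        .toInstant.toEpochMilli.toString
    }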
    /*============================== UDF registration and usage ======================================*/
    spark.udf.register("datetimeTrans", datetimeTrans(_: String))
    /* A UDF can also be registered from a lambda directly, e.g.:
     * spark.udf.register("hob_num", (s: String) => s.split(",").size)
     */
    orderDS.createOrReplaceTempView("sales")
    spark.sql("select detail_id,order_id,datetimeTrans(order_time) as tstamp from sales limit 10").show()
Output:
+---------+--------+-------------+
|detail_id|order_id| tstamp|
+---------+--------+-------------+
| d001| o001|1658283919000|
| d002| o001|1658283921000|
| d003| o002|1658370361000|
| d004| o003|1658399192000|
| d005| o003|1658399222000|
| d006| o004|1658464053000|
| d007| o005|1658733598000|
| d008| o005|1658733599000|
| d009| o006|1658898595000|
| d010| o007|1658902015000|
+---------+--------+-------------+
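The name registered above is used from SQL. To call the same logic directly in the DataFrame/Dataset API, the function can be wrapped with functions.udf; a minimal sketch (datetimeTransUdf is an illustrative name):

import org.apache.spark.sql.functions.udf
val datetimeTransUdf = udf(datetimeTrans _)
orderDS.select($"detail_id", $"order_id", datetimeTransUdf($"order_time").as("tstamp")).show()

For this particular conversion a UDF is not strictly necessary: the built-in unix_timestamp(order_time, 'yyyy-MM-dd HH:mm:ss') yields epoch seconds, which multiplied by 1000 gives the same millisecond values.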
A UDAF aggregates many input rows into one result per group. One way to implement it is to extend Spark's abstract class UserDefinedAggregateFunction and override each of its methods in detail, as follows. (This API is deprecated since Spark 3.0 in favor of the typed Aggregator; a sketch of that variant follows the class below.)
/* Override the abstract class UserDefinedAggregateFunction to define the UDAF logic:
 * compute the average price sum(amt)/sum(num) and the maximum quantity max(num)
 * seen within each aggregation group.
 **/
class MySumFunction extends UserDefinedAggregateFunction { // custom aggregate function
  // step 1: define the schema of the input rows
  override def inputSchema: StructType = {
    StructType(Seq(
      StructField("num", IntegerType), // quantity num, integer
      StructField("amt", DoubleType)   // amount amt, double
    ))
  }
  // step 2: define the aggregation buffer schema.
  // The computation needs to cache sum(num), sum(amt) and the current max quantity.
  override def bufferSchema: StructType = {
    new StructType().add("sum_num", IntegerType)
      .add("sum_amt", DoubleType)
      .add("max_num", IntegerType)
  }
  // step 3: define the data type returned by the aggregate function
  // (a single-value UDAF would return e.g. DoubleType; here two values are packed into an array)
  override def dataType: DataType = ArrayType(DoubleType)
  // step 4: initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
    buffer(1) = 0.00
    buffer(2) = 0
  }
  // step 5: define how the buffer is updated for each incoming row
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getInt(0) + input.getInt(0)           // accumulate num
    buffer(1) = buffer.getDouble(1) + input.getDouble(1)     // accumulate amt
    buffer(2) = math.max(buffer.getInt(2), input.getInt(0))  // track the maximum num
  }
  // step 6: define how partial buffers from different partitions are merged
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0)           // accumulate num
    buffer1(1) = buffer1.getDouble(1) + buffer2.getDouble(1)     // accumulate amt
    buffer1(2) = math.max(buffer1.getInt(2), buffer2.getInt(2))  // max logic
  }
  // step 7: compute the final result from the buffer
  override def evaluate(buffer: Row): Seq[Double] = {
    // (a single-value UDAF would just return buffer.getDouble(1)/buffer.getInt(0))
    Seq(buffer.getDouble(1) / buffer.getInt(0), // average sale price
        buffer.getInt(2).toDouble)              // maximum quantity sold
  }
  // whether the function is deterministic: the same input always produces the same output
  override def deterministic: Boolean = true
}
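Since Spark 3.0, the recommended replacement for UserDefinedAggregateFunction is the typed Aggregator API. A minimal sketch of the same logic in that style (AvgMaxBuf and MySumAggregator are illustrative names, not from the original code):

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

case class AvgMaxBuf(sumNum: Int, sumAmt: Double, maxNum: Int)

object MySumAggregator extends Aggregator[(Int, Double), AvgMaxBuf, (Double, Double)] {
  def zero: AvgMaxBuf = AvgMaxBuf(0, 0.0, 0)
  def reduce(b: AvgMaxBuf, in: (Int, Double)): AvgMaxBuf =
    AvgMaxBuf(b.sumNum + in._1, b.sumAmt + in._2, math.max(b.maxNum, in._1))
  def merge(b1: AvgMaxBuf, b2: AvgMaxBuf): AvgMaxBuf =
    AvgMaxBuf(b1.sumNum + b2.sumNum, b1.sumAmt + b2.sumAmt, math.max(b1.maxNum, b2.maxNum))
  def finish(b: AvgMaxBuf): (Double, Double) = (b.sumAmt / b.sumNum, b.maxNum.toDouble)
  def bufferEncoder: Encoder[AvgMaxBuf] = Encoders.product[AvgMaxBuf]
  def outputEncoder: Encoder[(Double, Double)] =
    Encoders.tuple(Encoders.scalaDouble, Encoders.scalaDouble)
}

// Registered via functions.udaf (Spark 3.0+); the SQL result is a struct with fields _1/_2:
// spark.udf.register("mySumAgg", org.apache.spark.sql.functions.udaf(MySumAggregator))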
Next, register the UDAF and use it in both styles:
    /*============================== UDAF registration and usage ====================================*/
    spark.udf.register("mySum", new MySumFunction)
    // use the UDAF from Spark SQL
    spark.sql(
      """
        |select a.product_id as product_id,
        |       a.ms[0] as avg_price, -- average product price
        |       a.ms[1] as max_num    -- max quantity sold in a single order
        |from
        |(select product_id, mySum(num,amt) as ms
        | from sales group by product_id) a
        |limit 100
        |""".stripMargin).show()
    // Dataset-API approach: apply the UDAF, then split the result array into columns
    val mySum = new MySumFunction
    val ds1 = orderDS.groupBy("product_id").agg(mySum(col("num"), col("amt")).as("ms"))
    val ds2 = ds1.withColumn("avg_price", $"ms".getItem(0))
      .withColumn("max_num", $"ms".getItem(1))
    ds2.show()
Results of both approaches:
Spark SQL UDAF result:
+----------+------------------+-------+
|product_id| avg_price|max_num|
+----------+------------------+-------+
| shoe001| 305.3| 2.0|
| skirt002| 62.1| 2.0|
| skirt001| 88.6| 1.0|
| tshirt001| 79.4| 2.0|
| jacket001|200.19999999999996| 2.0|
+----------+------------------+-------+
ds2 result:
+----------+--------------------+------------------+-------+
|product_id| ms| avg_price|max_num|
+----------+--------------------+------------------+-------+
| shoe001| [305.3, 2.0]| 305.3| 2.0|
| skirt002| [62.1, 2.0]| 62.1| 2.0|
| skirt001| [88.6, 1.0]| 88.6| 1.0|
| tshirt001| [79.4, 2.0]| 79.4| 2.0|
| jacket001|[200.199999999999...|200.19999999999996| 2.0|
+----------+--------------------+------------------+-------+
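The avg_price column shows raw double precision (e.g. 200.19999999999996 for jacket001). If two-decimal prices are wanted, the result can be rounded and the intermediate ms column dropped; a sketch using the built-in round function:

import org.apache.spark.sql.functions.round
ds2.drop("ms").withColumn("avg_price", round($"avg_price", 2)).show()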