To write a custom UDAF, we first extend UserDefinedAggregateFunction to implement our own aggregate function.
Let's start by looking at the basics of this class.
abstract class UserDefinedAggregateFunction extends Serializable {
inputSchema is a StructType describing the data types of this aggregate function's input arguments. For example, a UDAF that expects two input arguments of types DoubleType and LongType would use the following StructType:
new StructType()
  .add("doubleInput", DoubleType)
  .add("longInput", LongType)
The UDAF will then only accept input data that matches this schema.
def inputSchema: StructType
bufferSchema is a StructType describing the types of the values held in the aggregation buffer. For example, if a UDAF's buffer holds two values of types DoubleType and LongType, its schema would be:
new StructType()
  .add("doubleInput", DoubleType)
  .add("longInput", LongType)
and the buffer will likewise only hold data matching this schema.
def bufferSchema: StructType
dataType is the data type of this UDAF's return value.
def dataType: DataType
deterministic returns true if this function is deterministic, i.e. given the same input it always produces the same output.
def deterministic: Boolean
initialize initializes the aggregation buffer, for example with zero values. The contract is that merging two initial buffers should yield the initial buffer again, i.e. merge(initialBuffer, initialBuffer) should equal initialBuffer.
def initialize(buffer: MutableAggregationBuffer): Unit
update updates the given aggregation buffer with new input data; it is called once per input row.
def update(buffer: MutableAggregationBuffer, input: Row): Unit
merge merges two aggregation buffers and stores the merged result back into buffer1; it is called when two partially aggregated datasets are combined.
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
evaluate computes the final result of this UDAF from the given aggregation buffer.
def evaluate(buffer: Row): Any
apply creates a Column for this UDAF using the given Columns as input arguments.
@scala.annotation.varargs
def apply(exprs: Column*): Column = {
val aggregateExpression =
AggregateExpression(
ScalaUDAF(exprs.map(_.expr), this),
Complete,
isDistinct = false)
Column(aggregateExpression)
}
distinct creates a Column for this UDAF using the distinct values of the given Columns as input arguments.
@scala.annotation.varargs
def distinct(exprs: Column*): Column = {
val aggregateExpression =
AggregateExpression(
ScalaUDAF(exprs.map(_.expr), this),
Complete,
isDistinct = true)
Column(aggregateExpression)
}
}
/**
* A `Row` representing a mutable aggregation buffer.
*
* This is not meant to be extended outside of Spark.
*
* @since 1.5.0
*/
@InterfaceStability.Stable
abstract class MutableAggregationBuffer extends Row {
/** Update the ith value of this buffer. */
def update(i: Int, value: Any): Unit
}
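To see how these pieces fit together before the full exercise, here is a minimal sketch of a UDAF that simply counts its non-null inputs. The names MyCount, value and df are only illustrative, and the usage lines assume an active SparkSession named spark:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._

// Illustrative UDAF: counts non-null input values.
class MyCount extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = StructType(StructField("value", StringType) :: Nil)
  override def bufferSchema: StructType = StructType(StructField("count", LongType) :: Nil)
  override def dataType: DataType = LongType
  override def deterministic: Boolean = true
  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = 0L
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    if (!input.isNullAt(0)) buffer(0) = buffer.getLong(0) + 1L
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
  override def evaluate(buffer: Row): Any = buffer.getLong(0)
}

// Usage: either through the apply method on a DataFrame column...
val myCount = new MyCount
// df.agg(myCount(col("value"))).show()
// ...or registered by name for use in SQL:
// spark.udf.register("myCount", myCount)
// spark.sql("select myCount(value) from someTable").show()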
Let's work through a practical requirement.
First create the tables in Hive and load the data.
There are three tables in total: one user-behavior table, one city table, and one product table.
CREATE TABLE `user_visit_action`(
`date` string,
`user_id` bigint,
`session_id` string,
`page_id` bigint,
`action_time` string,
`search_keyword` string,
`click_category_id` bigint,
`click_product_id` bigint,
`order_category_ids` string,
`order_product_ids` string,
`pay_category_ids` string,
`pay_product_ids` string,
`city_id` bigint)
row format delimited fields terminated by '\t';
load data local inpath '/opt/module/datas/user_visit_action.txt' into table sparkpractice.user_visit_action;
CREATE TABLE `product_info`(
`product_id` bigint,
`product_name` string,
`extend_info` string)
row format delimited fields terminated by '\t';
load data local inpath '/opt/module/datas/product_info.txt' into table sparkpractice.product_info;
CREATE TABLE `city_info`(
`city_id` bigint,
`city_name` string,
`area` string)
row format delimited fields terminated by '\t';
load data local inpath '/opt/module/datas/city_info.txt' into table sparkpractice.city_info;
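After loading, a quick sanity check from spark-shell (with Hive support enabled) can confirm that all three tables contain data; the database name sparkpractice matches the load statements above:
spark.sql("select count(*) from sparkpractice.user_visit_action").show()
spark.sql("select count(*) from sparkpractice.product_info").show()
spark.sql("select count(*) from sparkpractice.city_info").show()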
Expected result:
Compute the top three hot products for each area, and annotate each product with its click distribution across the major cities; everything beyond the top two cities is shown as Other.
For example:
Area | Product Name | Click Count | City Remark |
---|---|---|---|
North China | Product A | 100000 | Beijing 21.2%, Tianjin 13.2%, Other 65.6% |
North China | Product P | 80200 | Beijing 63.0%, Taiyuan 10%, Other 27.0% |
North China | Product M | 666 | Beijing 63.0%, Taiyuan 10%, Other 27.0% |
North China | Product J | 222 | Dalian 28%, Liaoning 17.0%, Other 55.0% |
Requirement analysis:
Join all the tables to obtain every field we need to query.
Group by area and product id, and count the total clicks of each product in each area.
Within each area, order by click count in descending order.
Keep only the top three and save the result to the database.
The city remark requires a custom UDAF.
Concrete implementation:
Step 1: define the custom UDAF
// Declare the aggregate function
// 1. Extend UserDefinedAggregateFunction
// 2. Override its methods
class CalcCitydata extends UserDefinedAggregateFunction {

  // Input: a single city name column
  override def inputSchema: StructType = {
    StructType(Array(StructField("cityName", StringType)))
  }

  // Buffer: a city -> click-count map (e.g. Beijing -> 100, Tianjin -> 50) plus the total click count
  override def bufferSchema: StructType = {
    StructType(Array(StructField("cityToCount", MapType(StringType, LongType)), StructField("total", LongType)))
  }

  // The final result is the city remark string
  override def dataType: DataType = StringType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map[String, Long]()
    buffer(1) = 0L
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val cityName = input.getString(0)
    val map: Map[String, Long] = buffer.getAs[Map[String, Long]](0)
    buffer(0) = map + (cityName -> (map.getOrElse(cityName, 0L) + 1L))
    buffer(1) = buffer.getLong(1) + 1L
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val map1: Map[String, Long] = buffer1.getAs[Map[String, Long]](0)
    val map2: Map[String, Long] = buffer2.getAs[Map[String, Long]](0)
    buffer1(0) = map1.foldLeft(map2) {
      case (map, (k, v)) =>
        map + (k -> (map.getOrElse(k, 0L) + v))
    }
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  override def evaluate(buffer: Row): Any = {
    // Get the per-city click counts and sort them by click count in descending order
    val total: Long = buffer.getLong(1)
    val map: Map[String, Long] = buffer.getAs[Map[String, Long]](0)
    val remarkList: List[(String, Long)] = map.toList.sortWith(
      (left, right) => left._2 > right._2
    )
    if (remarkList.size > 2) {
      // Keep the top 2 cities and fold everything else into "Other"
      val topList: List[(String, Long)] = remarkList.take(2)
      val cityList: List[String] = topList.map {
        case (cityName, clickCount) =>
          f"$cityName ${clickCount.toDouble / total * 100}%.1f%%"
      }
      val otherPercent = remarkList.drop(2).map(_._2).sum.toDouble / total * 100
      cityList.mkString(", ") + f", Other $otherPercent%.1f%%"
    } else {
      val cityList: List[String] = remarkList.map {
        case (cityName, clickCount) =>
          f"$cityName ${clickCount.toDouble / total * 100}%.1f%%"
      }
      cityList.mkString(", ")
    }
  }
}
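Before wiring the function into the full query, it can be smoke-tested against a small in-memory DataFrame. This is only a sketch: the sample city names are made up, and it assumes an active SparkSession named spark (e.g. in spark-shell):
// Quick local test of CalcCitydata
import spark.implicits._

val testDF = Seq("Beijing", "Beijing", "Tianjin", "Shanghai").toDF("cityName")
val calcCitydata = new CalcCitydata
// Expected output along the lines of: Beijing 50.0%, Tianjin 25.0%, Other 25.0%
testDF.agg(calcCitydata($"cityName")).show(false)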
Step 2: business implementation
package com.atguigu.bigdata.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, MapType, StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
object SparkSQL04_Hive {

  def main(args: Array[String]): Unit = {

    // TODO Create the environment objects
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL04_Hive")
    val spark = SparkSession.builder().config(sparkConf).enableHiveSupport().getOrCreate()
    import spark.implicits._

    // Create and register the aggregate function
    val calcCitydata = new CalcCitydata
    spark.udf.register("calcCitydata", calcCitydata)

    // TODO Read the table data from Hive
    spark.sql("use sparkpractice")

    spark.sql(
      """
        | select
        |   v.*,
        |   p.product_name,
        |   c.city_name,
        |   c.area
        | from user_visit_action v
        | join product_info p on v.click_product_id = p.product_id
        | join city_info c on v.city_id = c.city_id
        | where v.click_product_id > -1
      """.stripMargin).createOrReplaceTempView("t1")

    spark.sql(
      """
        | select
        |   t1.area,
        |   t1.product_name,
        |   count(*) as areaProductClick,
        |   calcCitydata(t1.city_name) as cityRemark
        | from t1
        | group by t1.area, t1.product_name
      """.stripMargin).createOrReplaceTempView("t2")

    spark.sql(
      """
        |select
        |  *
        |from (
        |  select
        |    *,
        |    rank() over( partition by t2.area order by t2.areaProductClick desc ) as rank
        |  from t2
        |) t3
        |where rank <= 3
      """.stripMargin).show(20)

    // TODO Release resources
    spark.stop()
  }
}
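The analysis above also calls for saving the top-3 result to a database. A minimal sketch of doing so over JDBC is shown below; it is meant to go inside main in place of (or after) the final show(20) call, the MySQL URL, credentials and target table area_hot_product are placeholders rather than part of the original example, and the MySQL JDBC driver must be on the classpath:
// Collect the final query into a DataFrame and write it out over JDBC.
val resultDF = spark.sql(
  """
    |select * from (
    |  select *, rank() over( partition by t2.area order by t2.areaProductClick desc ) as rank from t2
    |) t3
    |where rank <= 3
  """.stripMargin)

resultDF.write
  .format("jdbc")
  .option("url", "jdbc:mysql://localhost:3306/sparkpractice")  // placeholder connection URL
  .option("dbtable", "area_hot_product")                       // placeholder target table
  .option("user", "root")                                      // placeholder credentials
  .option("password", "123456")
  .mode("append")
  .save()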