Custom UDAF Functions in Spark SQL

To write a custom UDAF we extend UserDefinedAggregateFunction and implement its methods.
Let's start with the basic structure of that class.

abstract class UserDefinedAggregateFunction extends Serializable {

  // A StructType describing the data types of this UDAF's input arguments.
  // For example, a UDAF that takes two input arguments of type DoubleType
  // and LongType would declare the following schema, and will only accept
  // input rows that match it:
  //   new StructType()
  //     .add("doubleInput", DoubleType)
  //     .add("longInput", LongType)
  def inputSchema: StructType

  // A StructType describing the layout of the aggregation buffer. For
  // example, a UDAF whose buffer holds two values of type DoubleType and
  // LongType would declare the following, and the buffer will only hold
  // values of these types:
  //   new StructType()
  //     .add("doubleInput", DoubleType)
  //     .add("longInput", LongType)
  def bufferSchema: StructType

  // The DataType of this UDAF's return value.
  def dataType: DataType

  // Returns true if this function is deterministic, i.e. given the same
  // input it always produces the same output.
  def deterministic: Boolean

  // Initializes the aggregation buffer, e.g. fills it with zero values.
  // Applying the merge function on two initial buffers should return the
  // initial buffer itself, i.e. merge(initialBuffer, initialBuffer) should
  // equal initialBuffer.
  def initialize(buffer: MutableAggregationBuffer): Unit

  // Updates the given aggregation buffer with the input; called once per
  // input row.
  def update(buffer: MutableAggregationBuffer, input: Row): Unit

  // Merges two aggregation buffers and stores the result back into buffer1.
  // Called when two partially aggregated datasets are combined.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit

  // Computes the final result of this UDAF from the given aggregation buffer.
  def evaluate(buffer: Row): Any

  // Creates a Column for this UDAF using the given Columns as input arguments.
  @scala.annotation.varargs
  def apply(exprs: Column*): Column = {
    val aggregateExpression =
      AggregateExpression(
        ScalaUDAF(exprs.map(_.expr), this),
        Complete,
        isDistinct = false)
    Column(aggregateExpression)
  }

  // Creates a Column for this UDAF using the distinct values of the given
  // Columns as input arguments.
  @scala.annotation.varargs
  def distinct(exprs: Column*): Column = {
    val aggregateExpression =
      AggregateExpression(
        ScalaUDAF(exprs.map(_.expr), this),
        Complete,
        isDistinct = true)
    Column(aggregateExpression)
  }
}

/**
 * A `Row` representing a mutable aggregation buffer.
 *
 * This is not meant to be extended outside of Spark.
 *
 * @since 1.5.0
 */
@InterfaceStability.Stable
abstract class MutableAggregationBuffer extends Row {

  /** Update the ith value of this buffer. */
  def update(i: Int, value: Any): Unit
}
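
Before jumping into the business case, here is a minimal sketch of how these pieces fit together: a hypothetical MyAverage UDAF that averages a LongType column. The class and field names are my own choices for illustration, not something from the original post.

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// Illustrative sketch: average of a Long column implemented as a UDAF
class MyAverage extends UserDefinedAggregateFunction {
  // One input argument of type LongType
  override def inputSchema: StructType = StructType(Array(StructField("input", LongType)))
  // Buffer holds the running sum and the row count
  override def bufferSchema: StructType =
    StructType(Array(StructField("sum", LongType), StructField("count", LongType)))
  // The final result is a Double
  override def dataType: DataType = DoubleType
  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L   // sum
    buffer(1) = 0L   // count
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1L
    }
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  override def evaluate(buffer: Row): Any =
    buffer.getLong(0).toDouble / buffer.getLong(1)
}

After registering it with spark.udf.register("myAverage", new MyAverage) it can be used in SQL like any built-in aggregate, or applied directly as a Column, e.g. df.groupBy($"dept").agg(new MyAverage()($"salary")).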

Let's put this into practice with a concrete requirement.
First create the tables in Hive and load the data.
There are three tables in total: one user behavior table, one city table, and one product table.


CREATE TABLE `user_visit_action`(
  `date` string,
  `user_id` bigint,
  `session_id` string,
  `page_id` bigint,
  `action_time` string,
  `search_keyword` string,
  `click_category_id` bigint,
  `click_product_id` bigint,
  `order_category_ids` string,
  `order_product_ids` string,
  `pay_category_ids` string,
  `pay_product_ids` string,
  `city_id` bigint)
row format delimited fields terminated by '\t';
load data local inpath '/opt/module/datas/user_visit_action.txt' into table sparkpractice.user_visit_action;
 
CREATE TABLE `product_info`(
  `product_id` bigint,
  `product_name` string,
  `extend_info` string)
row format delimited fields terminated by '\t';
load data local inpath '/opt/module/datas/product_info.txt' into table sparkpractice.product_info;
 
CREATE TABLE `city_info`(
  `city_id` bigint,
  `city_name` string,
  `area` string)
row format delimited fields terminated by '\t';
load data local inpath '/opt/module/datas/city_info.txt' into table sparkpractice.city_info;
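
As an optional sanity check, once a SparkSession with Hive support is available (see Step 2 below), the loaded tables can be inspected; the database name here follows the load statements above:

// Quick check that the data landed in the expected database/tables
spark.sql("use sparkpractice")
spark.sql("show tables").show()
spark.sql("select * from user_visit_action limit 5").show()
spark.sql("select count(*) from city_info").show()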

Expected result:
Compute the top three hot products in each region and, for each product, note how its clicks are distributed across the main cities; beyond two cities, group the remainder as 其他 (Other).
For example:

Region    Product    Clicks    City remark
华北      商品A      100000    北京21.2%, 天津13.2%, 其他65.6%
华北      商品P      80200     北京63.0%, 太原10%, 其他27.0%
华北      商品M      666       北京63.0%, 太原10%, 其他27.0%
华北      商品J      222       大连28%, 辽宁17.0%, 其他55.0%

Requirement analysis:
Join all three tables to get every field needed for the query.
Group by region and product id, and count the total clicks of each product in each region.
Within each region, sort products by click count in descending order.
Keep only the top three and save the result to the database.
The city remark column requires a custom UDAF.

Implementation:
Step 1: define the custom UDAF

// Declare the aggregate function
// 1. Extend UserDefinedAggregateFunction
// 2. Override its methods
class CalcCitydata extends UserDefinedAggregateFunction{
  override def inputSchema: StructType = {
    StructType(Array(StructField("cityName", StringType)))
  }
 
  // Buffer layout: a city -> click count map (e.g. 北京 -> 100, 天津 -> 50) plus the total click count
  override def bufferSchema: StructType = {
    StructType(Array(StructField("cityToCount", MapType(StringType, LongType)), StructField("total", LongType)))
  }
 
  override def dataType: DataType = StringType
 
  override def deterministic: Boolean = true
 
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map[String, Long]()
    buffer(1) = 0L
  }
 
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val cityName = input.getString(0)
    val map: Map[String, Long] = buffer.getAs[Map[String, Long]](0)
    buffer(0) = map + (cityName -> (map.getOrElse(cityName, 0L) + 1L))
    buffer(1) = buffer.getLong(1) + 1L
  }
 
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
 
    val map1: Map[String, Long] = buffer1.getAs[Map[String, Long]](0)
    val map2: Map[String, Long] = buffer2.getAs[Map[String, Long]](0)
 
    buffer1(0) = map1.foldLeft(map2){
      case ( map, (k, v) ) => {
        map + (k -> (map.getOrElse(k, 0L) + v))
      }
    }
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
 
  override def evaluate(buffer: Row): Any = {
 
    // Get the per-city click counts and sort them by click count in descending order; only the top 2 are shown individually
    val map: Map[String, Long] = buffer.getAs[Map[String, Long]](0)
 
    val remarkList: List[(String, Long)] = map.toList.sortWith(
      (left, right) => {
        left._2 > right._2
      }
    )
 
    if ( remarkList.size > 2 ) {
 
      val restList: List[(String, Long)] = remarkList.take(2)
      val cityList: List[String] = restList.map {
        case (cityName, clickCount) => {
          cityName + clickCount.toDouble / buffer.getLong(1) * 100 + "%"
        }
      }
      cityList.mkString(", ") + ", 其他 " + ( remarkList.tail.tail.map(_._2).sum / buffer.getLong(1) * 100 + "%" )
 
    } else {
      val cityList: List[String] = remarkList.map {
        case (cityName, clickCount) => {
          cityName + clickCount.toDouble / buffer.getLong(1) * 100 + "%"
        }
      }
      cityList.mkString(", ")
    }
 
 
  }

}
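
A quick way to sanity-check the UDAF before wiring it into the Hive queries is to run it against a small in-memory DataFrame. This is only a local test sketch; the sample values and view name are assumptions:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("UdafQuickTest").getOrCreate()
import spark.implicits._

spark.udf.register("calcCitydata", new CalcCitydata)

// Tiny sample of (area, product_name, city_name) click rows
Seq(
  ("华北", "商品A", "北京"),
  ("华北", "商品A", "北京"),
  ("华北", "商品A", "天津"),
  ("华北", "商品A", "太原")
).toDF("area", "product_name", "city_name").createOrReplaceTempView("clicks")

spark.sql(
  """
    |select area, product_name, count(*) as clicks, calcCitydata(city_name) as city_remark
    |from clicks
    |group by area, product_name
  """.stripMargin).show(false)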

Step 2: business implementation

package com.atguigu.bigdata.spark.sql
 
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, MapType, StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
 
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
 
 
object SparkSQL04_Hive {
 
  def main(args: Array[String]): Unit = {
 
    // TODO Create the environment objects
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL04_Hive")
    val spark = SparkSession.builder().config(sparkConf).enableHiveSupport().getOrCreate()
    import spark.implicits._
 
    // Create and register the aggregate function
    val calcCitydata = new CalcCitydata
 
    spark.udf.register("calcCitydata", calcCitydata)
 
    // TODO Read the table data from Hive (same database as the load statements above)
    spark.sql("use sparkpractice")
 
    spark.sql(
      """
        |            select
        |                v.*,
        |                p.product_name,
        |                c.city_name,
        |                c.area
        |            from user_visit_action v
        |            join product_info p on v.click_product_id = p.product_id
        |            join city_info c on v.city_id = c.city_id
        |            where v.click_product_id > -1
      """.stripMargin).createOrReplaceTempView("t1")
 
    spark.sql(
      """
        |    select
        |        t1.area,
        |        t1.product_name,
        |        count(*) as areaProductClick,
        |        calcCitydata(t1.city_name)
        |    from t1
        |    group by t1.area, t1.product_name
      """.stripMargin).createOrReplaceTempView("t2")
 
    spark.sql(
      """
        |select
        |   *
        |from (
        |    select
        |        *,
        |        rank() over( partition by t2.area order by t2.areaProductClick desc ) as rank
        |    from t2
        |) t3
        |where rank < 3
      """.stripMargin).show(20)
 
    // TODO Release resources
    spark.stop()
 
  }
}
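
The analysis also calls for saving the top-three result to a database, but the code above only calls show(20). A minimal sketch of writing the final result out over JDBC could look like the following; the MySQL URL, table name, and credentials are placeholders, not values from the original post:

// Hypothetical JDBC sink for the top-3 result; URL, table and credentials are assumptions
import java.util.Properties

val props = new Properties()
props.setProperty("user", "root")
props.setProperty("password", "123456")
props.setProperty("driver", "com.mysql.jdbc.Driver")

val result = spark.sql(
  """
    |select * from (
    |    select *,
    |           rank() over( partition by t2.area order by t2.areaProductClick desc ) as rank
    |    from t2
    |) t3
    |where rank <= 3
  """.stripMargin)

result.write
  .mode("overwrite")
  .jdbc("jdbc:mysql://localhost:3306/sparkpractice", "area_top3_product", props)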