Spark SQL: UDF and UDAF

1. Using UDFs

// Register the function
spark.udf.register("prefix1", (name: String) => {
    "Name:" + name
})
// Use the function in SQL
spark.sql("select *, prefix1(name) from users").show()
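
The same function can also be applied through the DataFrame API without SQL registration. A minimal sketch, assuming df is the users DataFrame queried above (df and the output column name "prefixed" are illustrative, not from the original):

// DataFrame-API variant of the same UDF (sketch; `df` is assumed)
import org.apache.spark.sql.functions.{col, udf}

val prefixUdf = udf((name: String) => "Name:" + name)
df.withColumn("prefixed", prefixUdf(col("name"))).show()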

2. Using UDAFs

2.1 Weakly typed

// 1. Define the UDAF (weakly typed; usable on versions before 3.0.0, where this API is not yet marked deprecated)
package com.shufang.rdd_ds_df

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}

class MyUDAF extends UserDefinedAggregateFunction {
  // IN: schema of the input arguments
  override def inputSchema: StructType = {
    StructType(
      Array(
        StructField("age", LongType)
      )
    )
  }

  // BUFFER: schema of the intermediate aggregation buffer
  override def bufferSchema: StructType = {
    StructType(
      Array(
        StructField("total", LongType),
        StructField("count", LongType)
      )
    )
  }

  // OUT: type of the final result
  override def dataType: DataType = LongType

  // Whether the function is deterministic (same input always yields the same output)
  override def deterministic: Boolean = true

  // Initialize the buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // Equivalent shorthand: buffer(0) = 0L; buffer(1) = 0L
    buffer.update(0, 0L)
    buffer.update(1, 0L)
  }

  // Update the buffer with one input row
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer.update(0, buffer.getLong(0) + input.getLong(0))
    buffer.update(1, buffer.getLong(1) + 1)
  }

  // Merge two buffers (partial aggregates from different partitions)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
  }

  // Compute the average; note Long/Long is integer division, so the result is truncated
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0) / buffer.getLong(1)
  }
}


// 2. Register and use it
spark.udf.register("ageAvg", new MyUDAF)
spark.sql("select ageAvg(id) as av from users").show()
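
For completeness, here is a minimal driver that exercises the class above. The SparkSession setup and the sample users data are assumptions for illustration; the article does not show them:

// Assumed setup, for illustration only
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("udaf-demo").getOrCreate()
import spark.implicits._

// Assumed sample data standing in for the article's users table
Seq((1L, "alice"), (2L, "bob"), (3L, "carol"))
  .toDF("id", "name")
  .createOrReplaceTempView("users")

spark.udf.register("ageAvg", new MyUDAF)
spark.sql("select ageAvg(id) as av from users").show() // av = 2 (integer division)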

2.2 Strongly typed (recommended since Spark 3.0.0)

// 1. Declare and implement the Aggregator
package com.shufang.rdd_ds_df

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

/**
 * Since Spark 3.0.0, an Aggregator[IN, BUF, OUT] should be registered
 * as a UDF via the functions.udaf(agg) method.
 */
case class Buff(var total: Long, var count: Long)

class MyUDAF1 extends Aggregator[Long, Buff, Long] {
  // Initialize the buffer
  override def zero: Buff = Buff(0L, 0L)

  // Fold one incoming value into the buffer
  override def reduce(b: Buff, a: Long): Buff = {
    b.count += 1
    b.total += a
    b
  }

  // Merge two buffers (partial aggregates from different partitions)
  override def merge(b1: Buff, b2: Buff): Buff = {
    b1.count = b1.count + b2.count
    b1.total = b1.total + b2.total
    b1
  }

  // Compute the final result; Long/Long is integer division, so the average is truncated
  override def finish(buff: Buff): Long = {
    buff.total / buff.count
  }

  // Encoder for the buffer type
  override def bufferEncoder: Encoder[Buff] = Encoders.product

  // Encoder for the output type
  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}


// 2. Register and use it; the registration differs: the Aggregator is wrapped with functions.udaf
import org.apache.spark.sql.functions

spark.udf.register("ageAvg", functions.udaf(new MyUDAF1()))
spark.sql("select ageAvg(id) as av from users").show()
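
Because functions.udaf wraps the Aggregator into an ordinary UserDefinedFunction, it can also be applied as a column expression without SQL. A sketch, again assuming df is the users DataFrame:

// DataFrame-API variant (sketch; `df` is assumed)
import org.apache.spark.sql.functions

val ageAvg = functions.udaf(new MyUDAF1())
df.select(ageAvg(df("id")).as("av")).show()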
 

2.3 Using a strongly typed UDAF on earlier versions

On versions before 3.0.0, a strongly typed UDAF cannot be registered for SQL; it has to be used through Spark SQL's typed DSL (domain-specific language) instead, as shown below.

// 1. Declare it; each row of the Dataset (a User) is the input value
package com.shufang.rdd_ds_df

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Buff is the same case class defined in section 2.2
// case class Buff(var total: Long, var count: Long)
class MyUDAF2 extends Aggregator[User, Buff, Long] {
  // Initialize the buffer
  override def zero: Buff = Buff(0L, 0L)

  // Fold one User into the buffer
  override def reduce(b: Buff, a: User): Buff = {
    b.count += 1
    b.total += a.id
    b
  }

  // Merge two buffers
  override def merge(b1: Buff, b2: Buff): Buff = {
    b1.count = b1.count + b2.count
    b1.total = b1.total + b2.total
    b1
  }

  // Compute the final result (integer division)
  override def finish(buff: Buff): Long = {
    buff.total / buff.count
  }

  // Encoder for the buffer type
  override def bufferEncoder: Encoder[Buff] = Encoders.product

  // Encoder for the output type
  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

// 2. Use it through the typed Dataset DSL
import org.apache.spark.sql.{Dataset, TypedColumn}

val column: TypedColumn[User, Long] = new MyUDAF2().toColumn
val ds: Dataset[User] = df.as[User]
ds.select(column).show()
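
Note that the article never shows the User case class. For MyUDAF2 to compile, it must at least expose the Long-typed id field accessed in reduce; a plausible (assumed) definition:

// Assumed definition, matching the a.id access in reduce above
case class User(id: Long, name: String)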
