// 自定义弱类型 — custom weakly-typed UDAFs (UserDefinedAggregateFunction)
package com.chen.sparksql.func
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
object UDAFDemo {
  /** Registers the weakly-typed UDAFs and runs them over a JSON-backed temp view.
    *
    * @param args optional: the first element overrides the input JSON path
    *             (defaults to "d:/user.json" for backward compatibility)
    */
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder().appName("UDAFDemo").master("local[2]").getOrCreate()
    // Generalized: allow the data file to be supplied on the command line
    // instead of only the original hard-coded Windows path.
    val path: String = args.headOption.getOrElse("d:/user.json")
    val df: DataFrame = spark.read.json(path)
    df.createOrReplaceTempView("user")
    // Register the custom aggregates so Spark SQL can resolve them by name.
    spark.udf.register("mySum", new MySum)
    spark.udf.register("myAvg", new MyAvg)
    spark.sql("select mySum(salary),myAvg(salary) from user").show
    spark.close()
  }
}
/** Weakly-typed UDAF that sums a single Double column, ignoring SQL NULLs. */
class MySum extends UserDefinedAggregateFunction {
  // A single Double input column.
  override def inputSchema: StructType = StructType(Seq(StructField("ele", DoubleType)))

  // The aggregation buffer holds only the running total.
  override def bufferSchema: StructType = StructType(Seq(StructField("sum", DoubleType)))

  override def dataType: DataType = DoubleType

  // Identical input always produces identical output.
  override def deterministic: Boolean = true

  // Start each group's total at zero.
  override def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer.update(0, 0D)

  // Fold one input row into the buffer; NULL values are skipped.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    if (!input.isNullAt(0)) buffer.update(0, buffer.getDouble(0) + input.getDouble(0))

  // Combine two partial totals (e.g. from different partitions).
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))

  // The final result is simply the accumulated sum.
  override def evaluate(buffer: Row): Any = buffer.getDouble(0)
}
/** Weakly-typed UDAF computing the average of a Double column.
  *
  * Keeps a running (sum, count) pair in the buffer; NULL inputs are ignored.
  */
class MyAvg extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = StructType(StructField("ele", DoubleType) :: Nil)
  override def bufferSchema: StructType = StructType(StructField("sum", DoubleType) :: StructField("count", LongType) :: Nil)
  override def dataType: DataType = DoubleType
  override def deterministic: Boolean = true
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0D // running sum
    buffer(1) = 0L // number of non-NULL rows seen
  }
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getDouble(0) + input.getAs[Double](0)
      buffer(1) = buffer.getLong(1) + 1L
    }
  }
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  // Fix: guard the zero-count case. The original returned sum/count
  // unconditionally, yielding NaN for an empty / all-NULL group; returning
  // SQL NULL matches the behavior of the built-in avg().
  override def evaluate(buffer: Row): Any =
    if (buffer.getLong(1) == 0L) null
    else buffer.getDouble(0) / buffer.getLong(1)
}
// 自定义强类型 — custom strongly-typed aggregator (Aggregator)
package com.chen.sparksql.func
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Encoders, Row, SparkSession, TypedColumn}
// Input record for the strongly-typed aggregator demo.
case class Dog(name: String, age: Int)
/** Aggregation buffer: running age total and number of records folded in. */
case class AgeAvg(sum: Int, count: Int) {
  // Explicit return type on the public member (inference hid it before).
  // NOTE(review): count == 0 still yields Double.NaN — acceptable here since
  // the aggregator only calls avg after counting at least one record per group.
  def avg: Double = sum.toDouble / count
}
object UDAFDemo2 {
  /** Runs the strongly-typed Aggregator over a small in-memory Dataset[Dog]. */
  def main(args: Array[String]): Unit = {
    // Fix: appName was copy-pasted as "UDAFDemo"; use this object's own name.
    val spark: SparkSession = SparkSession.builder().appName("UDAFDemo2").master("local[2]").getOrCreate()
    import spark.implicits._
    val ds: Dataset[Dog] = List(Dog("dahuang", 8), Dog("xiaohuang", 4), Dog("zhonghuang", 6)).toDS()
    // Fix: result column label typo "abg" -> "avg".
    val avg: TypedColumn[Dog, Double] = new MyAvg2().toColumn.name("avg")
    ds.select(avg).show
    spark.close()
  }
}
/** Strongly-typed aggregator producing the average age of a Dataset[Dog]. */
class MyAvg2 extends Aggregator[Dog, AgeAvg, Double] {
  // Identity element for the aggregation.
  override def zero: AgeAvg = AgeAvg(0, 0)
  // Fold one record into the buffer. Fix: the unused `name` binding is
  // replaced with `_`; the catch-all is kept so a null record is still
  // passed through unchanged rather than throwing a MatchError.
  override def reduce(b: AgeAvg, a: Dog): AgeAvg = a match {
    case Dog(_, age) => AgeAvg(b.sum + age, b.count + 1)
    case _ => b
  }
  // Combine two partial buffers (per-partition results).
  override def merge(b1: AgeAvg, b2: AgeAvg): AgeAvg = AgeAvg(b1.sum + b2.sum, b1.count + b2.count)
  // Produce the final value from the completed buffer.
  override def finish(reduction: AgeAvg): Double = reduction.avg
  // Encoders tell Spark how to serialize the buffer and the output.
  override def bufferEncoder: Encoder[AgeAvg] = Encoders.product
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}