UDF 实现（标量函数：每行输入产生一行输出）
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
/**
 * Demonstrates registering a scalar UDF and calling it from Spark SQL.
 *
 * Reads `datas/user.json` into the temp view `user`, registers `prefixName`
 * (which prepends `"Name: "` to its argument), and prints `age` alongside the
 * prefixed `username` for every row.
 */
object Spark02_SparkSQL_UDF {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()
    try {
      val df = spark.read.json("datas/user.json")
      df.createOrReplaceTempView("user")

      // A NULL username stays NULL instead of becoming the literal string "Name: null".
      spark.udf.register("prefixName", (name: String) => Option(name).map(n => s"Name: $n").orNull)

      // `show()` is side-effecting, so call it with parentheses (auto-application is deprecated).
      spark.sql("select age, prefixName(username) from user").show()
    } finally {
      // Always release the session, even when reading or querying fails.
      spark.close()
    }
  }
}
UDAF 实现
方式一：继承 UserDefinedAggregateFunction（弱类型 API，Spark 3.0 起已标记为弃用，新代码推荐使用下面的 Aggregator + functions.udaf）
package sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}
/**
 * Demonstrates a weakly typed UDAF (`UserDefinedAggregateFunction`) that computes
 * an average age from SQL.
 *
 * Reads `datas/user.json` into the temp view `user`, registers `ageAvg`, and prints
 * `ageAvg(age)` over the whole table.
 */
object Spark03_SparkSQL_UDAF {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()
    try {
      val df = spark.read.json("datas/user.json")
      df.createOrReplaceTempView("user")
      spark.udf.register("ageAvg", new MyAvgUDAF())
      spark.sql("select ageAvg(age) from user").show()
    } finally {
      // Always release the session, even when reading or querying fails.
      spark.close()
    }
  }

  /**
   * Average of a single `LONG` column, using truncating integer division.
   *
   * The buffer holds two longs: index 0 is the running `total` and index 1 is the
   * `count` of non-NULL inputs. NULL inputs are ignored. If no non-NULL input was
   * seen, the result is NULL, matching SQL's built-in `avg`.
   */
  class MyAvgUDAF extends UserDefinedAggregateFunction {
    /** One input column: `age`, a LONG. */
    override def inputSchema: StructType =
      StructType(Array(StructField("age", LongType)))

    /** Aggregation buffer layout: (total, count). */
    override def bufferSchema: StructType =
      StructType(
        Array(
          StructField("total", LongType),
          StructField("count", LongType)
        )
      )

    /** Result type: LONG (integer average, truncated). */
    override def dataType: DataType = LongType

    /** The same input always yields the same output. */
    override def deterministic: Boolean = true

    /** Start every buffer at total = 0, count = 0. */
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer.update(0, 0L)
      buffer.update(1, 0L)
    }

    /** Fold one input row into the buffer; NULL ages are skipped. */
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      // Row.getLong throws on NULL, and SQL aggregates ignore NULLs, so skip them.
      if (!input.isNullAt(0)) {
        buffer.update(0, buffer.getLong(0) + input.getLong(0))
        buffer.update(1, buffer.getLong(1) + 1L)
      }
    }

    /** Combine two partial buffers (e.g. from different partitions) into `buffer1`. */
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
      buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
    }

    /** Return total / count, or NULL when no non-NULL input was aggregated. */
    override def evaluate(buffer: Row): Any = {
      val count = buffer.getLong(1)
      // Guard against division by zero on empty (or all-NULL) input.
      if (count == 0L) null else buffer.getLong(0) / count
    }
  }
}
方式二：继承 Aggregator 类（强类型 API；Spark 3.0 起可通过 functions.udaf 包装后注册到 SQL 中使用）
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}
/**
 * Demonstrates a strongly typed `Aggregator` registered for SQL use via `functions.udaf`.
 *
 * Reads `datas/user.json` into the temp view `user`, registers `ageAvg`, and prints
 * `ageAvg(age)` over the whole table.
 */
object Spark03_SparkSQL_UDAF1 {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()
    try {
      val df = spark.read.json("datas/user.json")
      df.createOrReplaceTempView("user")
      spark.udf.register("ageAvg", functions.udaf(new MyAvgUDAF()))
      spark.sql("select ageAvg(age) from user").show()
    } finally {
      // Always release the session, even when reading or querying fails.
      spark.close()
    }
  }

  /** Mutable aggregation buffer: running `total` and number of values `count`. */
  case class Buff(var total: Long, var count: Long)

  /**
   * Average of `Long` inputs, using truncating integer division.
   *
   * NOTE(review): the input type is a primitive `Long`, so NULL ages probably cannot be
   * decoded — confirm, and filter NULLs in the query (or switch the input type) if needed.
   */
  class MyAvgUDAF extends Aggregator[Long, Buff, Long] {
    /** Empty buffer: total = 0, count = 0. */
    override def zero: Buff = Buff(0L, 0L)

    /** Add one input value. The buffer is updated in place and returned. */
    override def reduce(buff: Buff, in: Long): Buff = {
      buff.total += in
      buff.count += 1
      buff
    }

    /** Fold `buff2` into `buff1` and return `buff1`. */
    override def merge(buff1: Buff, buff2: Buff): Buff = {
      buff1.total += buff2.total
      buff1.count += buff2.count
      buff1
    }

    /**
     * Return total / count.
     *
     * A `Long` result cannot be NULL, so an empty input yields 0 instead of
     * failing with division by zero.
     */
    override def finish(buff: Buff): Long =
      if (buff.count == 0L) 0L else buff.total / buff.count

    override def bufferEncoder: Encoder[Buff] = Encoders.product

    override def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }
}
方式三：旧版本（Spark 3.0 之前）的强类型写法——Aggregator 不能直接注册给 SQL，只能通过 toColumn 在 Dataset 的 DSL 中使用
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql._
/**
 * Demonstrates the pre-Spark-3.0 style of using a typed `Aggregator`.
 *
 * The aggregator consumes whole `User` records and is applied through the Dataset DSL
 * (`toColumn`) rather than being registered for SQL.
 */
object Spark03_SparkSQL_UDAF2 {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()
    // Brings in the encoders needed by `df.as[User]`.
    import spark.implicits._
    try {
      val df = spark.read.json("datas/user.json")
      val ds: Dataset[User] = df.as[User]
      val udafCol: TypedColumn[User, Long] = new MyAvgUDAF().toColumn
      ds.select(udafCol).show()
    } finally {
      // Always release the session, even when reading or querying fails.
      spark.close()
    }
  }

  /** One row of `datas/user.json`, as mapped by `df.as[User]`. */
  case class User(username: String, age: Long)

  /** Mutable aggregation buffer: running `total` of ages and number of users `count`. */
  case class Buff(var total: Long, var count: Long)

  /** Average `age` over `User` records, using truncating integer division. */
  class MyAvgUDAF extends Aggregator[User, Buff, Long] {
    /** Empty buffer: total = 0, count = 0. */
    override def zero: Buff = Buff(0L, 0L)

    /** Add one user's age. The buffer is updated in place and returned. */
    override def reduce(buff: Buff, in: User): Buff = {
      buff.total += in.age
      buff.count += 1
      buff
    }

    /** Fold `buff2` into `buff1` and return `buff1`. */
    override def merge(buff1: Buff, buff2: Buff): Buff = {
      buff1.total += buff2.total
      buff1.count += buff2.count
      buff1
    }

    /**
     * Return total / count.
     *
     * A `Long` result cannot be NULL, so an empty Dataset yields 0 instead of
     * failing with division by zero.
     */
    override def finish(buff: Buff): Long =
      if (buff.count == 0L) 0L else buff.total / buff.count

    override def bufferEncoder: Encoder[Buff] = Encoders.product

    override def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }
}