Spark SQL创建UDF&UDAF案例详解

UDF:即用户自定义函数
UDAF:即用户自定义聚合函数

#UDF案例一:

import org.apache.spark.sql.{DataFrame, SparkSession}

//用户自定义函数
object _UdfDemo1 {
    def main(args: Array[String]): Unit = {
        val spark: SparkSession = SparkSession
          .builder()
          .master("local[*]")
          .appName(" ")
          .getOrCreate()

        val df: DataFrame = spark.read.json("sql/emp.json")
        df.createTempView("emp")

        //1.定义一个方法
        def fun1(word:String)={
            word.length
        }
        //2.注册函数 fun1 _ :将方法转成函数
//        spark.udf.register("mylength",fun1 _)
        //不用定义函数,直接使用匿名函数的写法
        spark.udf.register("myudf1",{word:String=>word.length})


        //3.SQ风格的写法
        spark.sql("select ename,myudf1(ename) as lg from emp where myudf1(ename)>4").show()


        spark.stop()

    }
}

#UDF案例二:

import org.apache.spark.sql.{DataFrame, SparkSession}

//用户自定义函数
object _UdfDemo2 {
    def main(args: Array[String]): Unit = {
        val spark: SparkSession = SparkSession
          .builder()
          .master("local[*]")
          .appName(" ")
          .getOrCreate()

        val df: DataFrame = spark.read.json("sql/emp.json")
        df.createTempView("emp")

        //需求分析:显示每个员工的工资等级,
        // sal>3000 显示level3
        // sal>1500 显示level2
        // 其他的显示level1

        //1.自定义一个fun2函数
        def fun2(num:Double)={
            if (num>3000)
                "level3"
            else if (num>1500)
                "level3"
            else
                "level1"
        }

        //2.注册自定义函数
        spark.udf.register("myudf2",fun2 _)

        //3.写SQL风格语句,使用自定义函数
        val sql =
            """
              |select
              |count(1),
              |myudf2(sal)
              |from
              |emp
              |group by
              |myudf2(sal)
              |""".stripMargin

        //4.传入sql,显示数据
        spark.sql(sql).show()
        spark.stop()

    }
}

#UDAF案例一:

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

//用户自定义聚合函数
object _UdafDemo3 {
    def main(args: Array[String]): Unit = {
        val spark: SparkSession = SparkSession
          .builder()
          .master("local[*]")
          .appName(" ")
          .getOrCreate()

        val df: DataFrame = spark.read.json("sql/emp.json")
        df.createTempView("emp")
        /**
         * 名词解释:
         * udf:一对一关系 传入一行输出一行
         * dfaf:多对一关系 传入多行输出一行
         */

        //需求:
        //查询每个部门的平均工资
        //2.注册函数
        spark.udf.register("myudaf1",new MyUDAF)

        //3.编写sql语句
        val sql =
            """
              |select
              |deptno,
              |myudaf1(sal)
              |from
              |emp
              |group by deptno
              |""".stripMargin

        //4,显示结果
        spark.sql(sql).show()

        spark.stop()
    }
}
/**
 * 1.
 * 需要自定一个类型,用来继承UDAF的相关类型
 * 需要根据需求重写继承类的方法
 */
class MyUDAF extends UserDefinedAggregateFunction{
    //--描述进入函数的参数的类型
    override def inputSchema: StructType = StructType(
        Array(
            StructField("num",DoubleType)
        )
    )

    //--描述计算过程中涉及到的变量的类型
    override def bufferSchema: StructType = StructType(
        Array(
            StructField("sum",DoubleType),
            StructField("count",LongType)
        )
    )

    //--描述计算结果的类型
    override def dataType: DataType = DoubleType

    //--描述函数的稳定性
    override def deterministic: Boolean = true

    //--对计算过程中涉及到的两个变量进行初始化
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
        //buffer的第一个元素表示sum
        buffer(0) = 0D
        //buffer的第二个元素表示count
        buffer(1) = 0L
    }

    /**
     * 预聚合操作
     * @param buffer    参数buffer指的使当前的缓存中的数据,相当this
     * @param input     表示刚刚进入函数内的这一条记录
     */
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        //update 用于更新数据
        buffer.update(0,buffer.getDouble(0)+input.getDouble(0))
        //+1操作
        buffer.update(1,buffer.getLong(1)+1)
    }

    /**
     * 分区间的合并操作
     * @param buffer1   要返回的数据
     * @param buffer2   另外一个分区的数据
     */
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1.update(0,buffer1.getDouble(0)+buffer2.getDouble(0))
        buffer1.update(1,buffer1.getLong(1)+buffer2.getLong(1))
    }

    //用于计算结果
    override def evaluate(buffer: Row): Any ={
        buffer.getDouble(0)/buffer.getLong(1)
    }
}

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值