Spark SQL自定义函数UDF、UDAF聚合函数以及开窗函数的使用

一、UDF的使用

1、Spark SQL自定义函数就是可以通过scala写一个类,然后在SparkSession上注册一个函数并对应这个类,然后在SQL语句中就可以使用该函数了,首先定义UDF函数,那么创建一个SqlUdf类,并且继承UDF1或UDF2等等,UDF后边的数字表示了当调用函数时会传入进来有几个参数,最后一个R则表示返回的数据类型,如下图所示:

2、这里选择继承UDF2,如下代码所示:

package com.udf

import org.apache.spark.sql.api.java.UDF2

class SqlUDF extends UDF2[String,Integer,String] {
  override def call(t1: String, t2: Integer): String = {
    t1+"_udf_test_"+t2
  }
}

3、然后在SparkSession生成的对象上通过sparkSession.udf.register进行注册,如下代码所示:

    val conf=new SparkConf().setAppName("AppUdf").setMaster("local")
    val sparkSession=SparkSession.builder().config(conf).getOrCreate()
    //指定函数名为:splicing_t1_t2 此函数名只有通过udf.register注册过之后才能够被使用,第二个参数是继承与UDF的类
    //第三个参数是返回类型
    sparkSession.udf.register("splicing_t1_t2",new SqlUDF,DataTypes.StringType)

4、生成模拟数据,并注册一个临时表,如下代码所示:

var rows=Seq[Row]()
    val random=new Random()
    for(i <- 0 until 10){
      val name="name"+i
      val age=random.nextInt(30)%15+15
      val row=Row(name,age)
      rows +:=row
    }
    val rowsRDD=sparkSession.sparkContext.parallelize(rows)
    val schema=DataTypes.createStructType(Array[StructField](
      DataTypes.createStructField("name",DataTypes.StringType,true),
      DataTypes.createStructField("age",DataTypes.IntegerType,true))
    )

    val df=sparkSession.createDataFrame(rowsRDD,schema)
    df.createOrReplaceTempView("person")
    df.show()

输出 结果如下图所示:

5、在sql语句中使用自定义函数splicing_t1_t2,然后将函数的返回结果定义一个别名name_age,如下代码所示:

val sql="SELECT name,age,splicing_t1_t2(name,age) name_age FROM person"
sparkSession.sql(sql).show()

输出结果如下:

6、由此可以看到在自定义的UDF类中,想如何操作都可以了,完整代码如下;

package com.udf

import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{DataTypes, StructField}

import scala.util.Random

object AppUdf {
  def main(args:Array[String]):Unit={
    val conf=new SparkConf().setAppName("AppUdf").setMaster("local")
    val sparkSession=SparkSession.builder().config(conf).getOrCreate()
    //指定函数名为:splicing_t1_t2 此函数名只有通过udf.register注册过之后才能够被使用,第二个参数是继承与UDF的类
    //第三个参数是返回类型
    sparkSession.udf.register("splicing_t1_t2",new SqlUDF,DataTypes.StringType)

    var rows=Seq[Row]()
    val random=new Random()
    for(i <- 0 until 10){
      val name="name"+i
      val age=random.nextInt(30)%15+15
      val row=Row(name,age)
      rows +:=row
    }
    val rowsRDD=sparkSession.sparkContext.parallelize(rows)
    val schema=DataTypes.createStructType(Array[StructField](
      DataTypes.createStructField("name",DataTypes.StringType,true),
      DataTypes.createStructField("age",DataTypes.IntegerType,true))
    )

    val df=sparkSession.createDataFrame(rowsRDD,schema)
    df.createOrReplaceTempView("person")

    val sql="SELECT name,age,splicing_t1_t2(name,age) name_age FROM person"
    sparkSession.sql(sql).show()

    sparkSession.close()
  }
}

二、无类型的用户自定于聚合函数:UserDefinedAggregateFunction

1、它是一个接口,需要实现的方法有:

class AvgAge extends UserDefinedAggregateFunction {
  //设置输入数据的类型,指定输入数据的字段与类型,它与在生成表时创建字段时的方法相同
  override def inputSchema: StructType = ???
  //指定缓冲数据的字段与类型
  override def bufferSchema: StructType = ???
  //指定数据的返回类型
  override def dataType: DataType = ???
  //指定是否是确定性,对输入数据进行一致性检验,是一个布尔值,当为true时,表示对于同样的输入会得到同样的输出
  override def deterministic: Boolean = ???
  //initialize用户初始化缓存数据
  override def initialize(buffer: MutableAggregationBuffer): Unit = ???
  //当有新的输入数据时,update就会更新缓存变量
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = ???
  //将更新的缓存变量进行合并,有可能每个缓存变量的值都不在一个节点上,最终是要将所有节点的值进行合并才行
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = ???
  //一个计算方法,用于计算我们的最终结果
  override def evaluate(buffer: Row): Any = ???
}

这是一个计算平均年龄的自定义聚合函数,实现代码如下所示:

package com.udf

import java.math.BigDecimal

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}

/**
 * 用于计算平均年龄的聚合函数
 */
class AvgAge extends UserDefinedAggregateFunction {
  /**
   * 设置输入数据的类型,指定输入数据的字段与类型,它与在生成表时创建字段时的方法相同
   * 比如计算平均年龄,输入的是age这一列的数据,注意此处的age名称可以随意命名
   * @return
   */
  override def inputSchema: StructType = DataTypes.createStructType(Array[StructField](DataTypes.createStructField("age",DataTypes.IntegerType,true)))

  /**
   * 指定缓冲数据的字段与类型,相当于中间变量
   * 由于要计算平均值,首先要计算出总和与个数才能计算平均值,因此需要进来一个值就要累加并计数才能计算出平均值
   * 所以要定义两个变量作为累加和以及计数的变量
   * @return
   */
  override def bufferSchema: StructType = DataTypes.createStructType(Array[StructField](
    DataTypes.createStructField("sum",DataTypes.DoubleType,true),
    DataTypes.createStructField("count",DataTypes.IntegerType,true)
  ))
  //指定数据的返回类型,由于平均值是double类型,因此定义DoubleType
  override def dataType: DataType = DataTypes.DoubleType
  /**
   * 设置该函数是否为幂等函数
   * 幂等函数:即只要输入的数据相同,结果一定相同
   * true表示是幂等函数,false表示不是
   * @return
   */
  override def deterministic: Boolean = true

  /**
   * initialize用于初始化缓存变量的值,也就是初始化bufferSchema函数中定义的两个变量的值sum,count
   * 其中buffer(0)就表示sum值,buffer(1)就表示count的值,如果还有第3个,则使用buffer(3)表示
   * @param buffer
   */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0,0.0) //或使用buffer(0)=0.0
    buffer.update(1,0) //或使用buffer(1)=0
  }

  /**
   * 当有一行数据进来时就会调用update一次,有多少行就会调用多少次,input就表示在调用自定义函数中有多少个参数,最终会将
   * 这些参数生成一个Row对象,在使用时可以通过input.getString或inpu.getLong等方式获得对应的值
   * 缓冲中的变量sum,count使用buffer(0)或buffer.getDouble(0)的方式获取到
   * @param buffer
   * @param input
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val sum=buffer.getDouble(0)
    val count=buffer.getInt(1)
    buffer.update(0,sum+input.getInt(0).toDouble)
    buffer.update(1,count+1)
  }

  /**
   * 将更新的缓存变量进行合并,有可能每个缓存变量的值都不在一个节点上,最终是要将所有节点的值进行合并才行
   * 其中buffer1是本节点上的缓存变量,而buffer2是从其他节点上过来的缓存变量然后转换为一个Row对象,然后将buffer2
   * 中的数据合并到buffer1中去即可
   * @param buffer1
   * @param buffer2
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val sum1=buffer1.getDouble(0)
    val count1=buffer1.getInt(1)
    val sum2=buffer2.getDouble(0)
    val count2=buffer2.getInt(1)
    buffer1.update(0,sum1+sum2)
    buffer1.update(1,count1+count2)
  }

  /**
   * 一个计算方法,用于计算我们的最终结果,也就相当于返回值
   * @param buffer
   * @return
   */
  override def evaluate(buffer: Row): Any = {
    val bd = new BigDecimal(buffer.getDouble(0)/buffer.getInt(1).toDouble)
    bd.setScale(2, BigDecimal.ROUND_HALF_UP).doubleValue//保留两位小数
  }
}

2、注册该类,并指定到一个自定义函数中,如下图所示:

3、在表中加一列字段id,通过GROUP BY进行分组计算,如

4、在sql语句中使用group_age_avg,如下图所示:

输出结果如下图所示:

5、完整代码如下:

package com.udf

import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{DataTypes, StructField}

import scala.util.Random

object AppUdf {
  def main(args:Array[String]):Unit={
    val conf=new SparkConf().setAppName("AppUdf").setMaster("local")
    val sparkSession=SparkSession.builder().config(conf).getOrCreate()
    //指定函数名为:splicing_t1_t2 此函数名只有通过udf.register注册过之后才能够被使用,第二个参数是继承与UDF的类
    //第三个参数是返回类型
    sparkSession.udf.register("splicing_t1_t2",new SqlUDF,DataTypes.StringType)
    //UDAF不用设置返回类型,因此使用两个参数即可
    sparkSession.udf.register("group_age_avg",new AvgAge)
    var rows=Seq[Row]()
    val random=new Random()
    for(i <- 0 until 10){
      val name="name"+i
      val age=random.nextInt(30)%15+15
      val row=Row(random.nextInt(2),name,age)
      rows +:=row
    }
    val rowsRDD=sparkSession.sparkContext.parallelize(rows)
    val schema=DataTypes.createStructType(Array[StructField](
      DataTypes.createStructField("id",DataTypes.IntegerType,true),
      DataTypes.createStructField("name",DataTypes.StringType,true),
      DataTypes.createStructField("age",DataTypes.IntegerType,true))
    )

    val df=sparkSession.createDataFrame(rowsRDD,schema)
    df.createOrReplaceTempView("person")
    df.show()

    val sql="SELECT id, group_age_avg(age) avg_age FROM person GROUP BY id"
    sparkSession.sql(sql).show()

    sparkSession.close()
  }
}

三、类型安全的用户自定于聚合函数:Aggregator

1、它是一个接口,需要继承与Aggregator,而Aggregator有3个参数,分别是IN,BUF,OUT,IN表示输入的值是什么,可以是一个自定类对象包含多个值,也可以是单个值,BUF就是需要用来缓存值使用的,如果需要缓存多个值也需要定义一个对象,而返回值也可以是一个对象返回多个值,需要实现的方法有:

package com.udf

import org.apache.spark.sql.Encoder
import org.apache.spark.sql.expressions.Aggregator

case class DataBuf(var sum:Double,var count:Int)
object AvgAgeAggregator extends Aggregator[Int,DataBuf,Double]{
  /**
   * 相当于UserDefinedAggregateFunction中的initialize函数,用于初始化DataBuf对象的值,此DataBuf是自定义类型的
   * @return
   */
  override def zero: DataBuf = ???

  /**
   * reduce函数相当于UserDefinedAggregateFunction中的update函数,当有新的数据a时,更新中间数据b
   * @param b
   * @param a
   * @return
   */
  override def reduce(b: DataBuf, a: Int): DataBuf = ???

  /**
   * merge函数相当于UserDefinedAggregateFunction中的merge函数,对两个值进行 合并,
   * 因为有可能每个缓存变量的值都不在一个节点上,最终是要将所有节点的值进行合并才行,将b2中的值合并到b1中
   * @param b1
   * @param b2
   * @return
   */
  override def merge(b1: DataBuf, b2: DataBuf): DataBuf = ???

  /**
   * finish相当于UserDefinedAggregateFunction中的evaluate,是一个计算方法,用于计算我们的最终结果,也就相当于返回值
   * 返回值可以是一个对象
   * @param reduction
   * @return
   */
  override def finish(reduction: DataBuf): Double = ???

  /**
   *  缓冲数据编码方式
   * @return
   */
  override def bufferEncoder: Encoder[DataBuf] = ???

  /**
   *  最终数据输出编码方式
   * @return
   */
  override def outputEncoder: Encoder[Double] = ???
}

2、具体实现如下代码所示:

package com.udf

import java.math.BigDecimal

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
case class DataBuf(var sum:Double,var count:Int)
object AvgAgeAggregator extends Aggregator[Int,DataBuf,Double]{

  /**
   * 相当于UserDefinedAggregateFunction中的initialize函数,用于初始化DataBuf对象的值,此DataBuf是自定义类型的
   * @return
   */
  override def zero: DataBuf = DataBuf(0.0,0)

  /**
   * reduce函数相当于UserDefinedAggregateFunction中的update函数,当有新的数据a时,更新中间数据b
   * @param b
   * @param a
   * @return
   */
  override def reduce(b: DataBuf, a: Int): DataBuf = {
    b.count+=1
    b.sum+=a.toDouble
    b
  }

  /**
   * merge函数相当于UserDefinedAggregateFunction中的merge函数,对两个值进行 合并,
   * 因为有可能每个缓存变量的值都不在一个节点上,最终是要将所有节点的值进行合并才行,将b2中的值合并到b1中
   * @param b1
   * @param b2
   * @return
   */
  override def merge(b1: DataBuf, b2: DataBuf): DataBuf = {
    b1.sum+=b2.sum
    b1.count+=b2.count
    b1
  }

  /**
   * finish相当于UserDefinedAggregateFunction中的evaluate,是一个计算方法,用于计算我们的最终结果,也就相当于返回值
   * 返回值可以是一个对象
   * @param reduction
   * @return
   */
  override def finish(reduction: DataBuf): Double = {
    val bd = new BigDecimal(reduction.sum/reduction.count.toDouble)
    bd.setScale(2, BigDecimal.ROUND_HALF_UP).doubleValue//保留两位小数
  }

  /**
   *  缓冲数据编码方式,如果Encoder中指定的类型时对象,则设置为product,如果是具体的类型,则需设置为具体的类型
   * @return
   */
  override def bufferEncoder: Encoder[DataBuf] = Encoders.product

  /**
   *  最终数据输出编码方式,如果Encoder中指定的类型,则设置为具体的类型,比如Double则设置为scalaDouble
   * @return
   */
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

3、而使用此聚合函数就不能通过注册函数来使用了,需要通过Dataset对象的select来使用,如下图所示:

执行结果如下图所示:

因此无类型的用户自定于聚合函数:UserDefinedAggregateFunction和类型安全的用户自定于聚合函数:Aggregator之间的区别是

(1)UserDefinedAggregateFunction不能够带类型而Aggregator是可以带类型的。

(2)使用方法不同UserDefinedAggregateFunction通过注册可以在DataFram的sql语句中使用,而Aggregator必须是在Dataset上使用。

四、开窗函数的使用

1、在Spark 1.5.x版本以后,在Spark SQL和DataFrame中引入了开窗函数,其中比较常用的开窗函数就是row_number该函数的作用是根据表中字段进行分组,然后根据表中的字段排序;其实就是根据其排序顺序,给组中的每条记录添加一个序号;且每组的序号都是从1开始,可利用它的这个特性进行分组取top-n。它是放在select子句中的,其格式为:

ROW_NUMBER() OVER (PARTITION BY area ORDER BY click_count DESC) rank 

首先可以,在SELECT查询时,使用row_number()函数,其次row_number()函数后面先跟上OVER关键字,然后括号中,是PARTITION BY,也就是说根据哪个字段进行分组,其次是可以用ORDER BY进行组内排序, 然后row_number()就可以给每个组内的行,一个组内行号,然后rank就是每一组的行号

2、使用方法的sql语句为:

SELECT id,name,age,row_number() OVER (PARTITION BY id ORDER BY age) rank FROM person ORDER BY id desc,rank desc

意思是在sql语句中加一个rank字段,该字段记录了以id为分组,在组内按照age升序排序,并记录行号,最后先按照id降序排序,如果id相同则按照rank降序排序

3、代码如下:

package com.udf

import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{DataTypes, StructField}

import scala.util.Random

object AppUdf {
  def main(args:Array[String]):Unit={
    val conf=new SparkConf().setAppName("AppUdf").setMaster("local")
    val sparkSession=SparkSession.builder().config(conf).getOrCreate()
    //指定函数名为:splicing_t1_t2 此函数名只有通过udf.register注册过之后才能够被使用,第二个参数是继承与UDF的类
    //第三个参数是返回类型
    sparkSession.udf.register("splicing_t1_t2",new SqlUDF,DataTypes.StringType)
    //UDAF不用设置返回类型,因此使用两个参数即可
    sparkSession.udf.register("group_age_avg",new AvgAge)
    var rows=Seq[Row]()
    val random=new Random()
    for(i <- 0 until 10){
      val name="name"+i
      val age=random.nextInt(30)%15+15
      val row=Row(random.nextInt(2),name,age)
      rows +:=row
    }
    val rowsRDD=sparkSession.sparkContext.parallelize(rows)
    val schema=DataTypes.createStructType(Array[StructField](
      DataTypes.createStructField("id",DataTypes.IntegerType,true),
      DataTypes.createStructField("name",DataTypes.StringType,true),
      DataTypes.createStructField("age",DataTypes.IntegerType,true))
    )

    val df=sparkSession.createDataFrame(rowsRDD,schema)
    df.createOrReplaceTempView("person")
    df.show()

    val sql="SELECT id,name,age,row_number() OVER (PARTITION BY id ORDER BY age) rank FROM person ORDER BY id desc,rank desc"
    sparkSession.sql(sql).show()
    sparkSession.close()
  }
}

输出结果如下:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值