UDF: user-defined function. A UDF operates on the column values of a single row and returns one value per row; once registered, it can be called directly in SQL.
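A minimal sketch of defining and registering a UDF (assuming Spark 2.x+; the session setup, view name, and the STRLEN function are illustrative, not from the original):

import org.apache.spark.sql.SparkSession

object UDF {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName("UDF").getOrCreate()
    import spark.implicits._
    val frame = List("zhangsan", "lisi", "wangwu").toDF("name")
    frame.createOrReplaceTempView("students")
    // Register a UDF that computes the length of a name; callable from SQL by the registered name.
    spark.udf.register("STRLEN", (name: String) => name.length)
    spark.sql("select name, STRLEN(name) as length from students").show()
  }
}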

UDAF: user-defined aggregate function
package com.bjsxt.scalaspark.sql.UDF_UDAF
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
/**
 * UserDefinedAggregateFunction: abstract base class for user-defined aggregate functions
 */
class MyUDAF extends UserDefinedAggregateFunction {
  // Schema of the input data passed to the function
  def inputSchema: StructType = {
    DataTypes.createStructType(Array(DataTypes.createStructField("uuuu", StringType, true)))
  }

  /**
   * Initializes the aggregation buffer for each group.
   * Initialization happens in two places:
   * 1. On the map side, within each RDD partition, rows are grouped by the GROUP BY key and each group gets an initial value.
   * 2. On the reduce side, each GROUP BY group again gets an initial value before partial results are merged.
   */
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
  }

  // Called for each new input row in a group; updates that group's running aggregate.
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Int](0) + 1
  }

  // Merges the partial aggregates computed on different nodes/partitions into one.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
  }

  // Schema of the intermediate buffer used during aggregation
  def bufferSchema: StructType = {
    DataTypes.createStructType(Array(DataTypes.createStructField("QQQQ", IntegerType, true)))
  }

  // Produces the final aggregate value; its type must match dataType.
  def evaluate(buffer: Row): Any = {
    buffer.getAs[Int](0)
  }

  // Return type of the function's final value
  def dataType: DataType = {
    DataTypes.IntegerType
  }

  // Deterministic: the same input always produces the same output.
  def deterministic: Boolean = {
    true
  }
}
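Execution flow of the UDAF above: initialize creates a zeroed buffer for each group, update folds each input row into that buffer within a partition, merge combines the partial buffers from different partitions, and evaluate returns the buffer's final value as the count.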
object UDAF {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName("UDAF").getOrCreate()
    val nameList: List[String] = List[String]("zhangsan", "lisi", "lisi", "lisi", "lisi", "wangwu", "zhangsan", "lisi", "zhangsan", "wangwu")
    import spark.implicits._
    val frame: DataFrame = nameList.toDF("name")
    frame.createOrReplaceTempView("students")
    // Equivalent to: select name, count(*) from students group by name
    // Register the UDAF under the name NAMECOUNT
    spark.udf.register("NAMECOUNT", new MyUDAF())
    spark.sql("select name, NAMECOUNT(name) as count from students group by name").show(100)
  }
}
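For the sample nameList above, the query should return lisi → 5, zhangsan → 3, wangwu → 2 (row order may vary).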
Window functions
row_number() over (partition by XXX order by XXX)
In Spark 1.x, a SQL statement that used a window function had to be executed through HiveContext, which by default cannot be created locally; since Spark 2.0 the regular SQL engine supports window functions as well. Window functions were also added to MySQL starting with MySQL 8.
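A minimal runnable sketch of row_number() on a plain local SparkSession (assuming Spark 2.x+, where window functions no longer require Hive support; the sample rows mirror the expected-rank comments in the Hive example below):

import org.apache.spark.sql.SparkSession

object OverFunctionLocal {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName("overLocal").getOrCreate()
    import spark.implicits._
    // (riqi, leibie, jine) sample rows, two categories A and B
    val sales = List(("5", "A", 200), ("3", "A", 100), ("4", "A", 80), ("1", "B", 100), ("8", "B", 90)).toDF("riqi", "leibie", "jine")
    sales.createOrReplaceTempView("sales")
    // rank restarts at 1 inside each leibie partition, ordered by jine descending
    spark.sql("select riqi, leibie, jine, row_number() over (partition by leibie order by jine desc) as rank from sales").show()
  }
}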
- hive
import org.apache.spark.sql.{SaveMode, SparkSession}

/**
 * over window function
 * row_number() over (partition by xx order by xx) as rank
 * rank restarts at 1 within each partition
 */
object OverFunctionOnHive {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("over").enableHiveSupport().getOrCreate()
    spark.sql("use spark")
    spark.sql("create table if not exists sales (riqi string, leibie string, jine int) " +
      "row format delimited fields terminated by '\t'")
    spark.sql("load data local inpath '/root/test/sales' into table sales")
    /**
     * rank restarts at 1 within each group:
     * 5 A 200 --- 1
     * 3 A 100 --- 2
     * 4 A 80  --- 3
     * 7 A 60  --- 4
     *
     * 1 B 100 --- 1
     * 8 B 90  --- 2
     * 6 B 80  --- 3
     * 1 B 70  --- 4
     */
    val result = spark.sql(
      "select riqi, leibie, jine " +
        "from (" +
        "select riqi, leibie, jine, " +
        "row_number() over (partition by leibie order by jine desc) rank " +
        "from sales) t " +
        "where t.rank <= 3")
    result.write.mode(SaveMode.Append).saveAsTable("salesResult")
    result.show(100)
  }
}
- mysql
import java.util.Properties
import org.apache.spark.sql.{DataFrame, SparkSession}
/**
 * Using a window function in MySQL:
 * row_number() over (partition by xx order by xx) as rank
 * rank restarts at 1 within each partition
 * Requires the sales table:
 * create table sales (riqi VARCHAR(20), leibie VARCHAR(20), jine INT)
 * (jine is declared INT so that "order by jine desc" sorts numerically.)
 * Data can be imported with Excel: copy and paste.
 */
object OverFunctioOnMySQL {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName("overfunctioonmysql").getOrCreate()
    val properties = new Properties()
    properties.setProperty("user", "root")
    properties.setProperty("password", "123456")
    // rank is a reserved word in MySQL 8, so the alias is backquoted
    val sql = "select riqi, leibie, jine from " +
      "(select riqi, leibie, jine, row_number() over (partition by leibie order by jine desc) as `rank` " +
      "from sales) t where t.rank <= 3"
    val person: DataFrame = spark.read.jdbc("jdbc:mysql://192.168.179.101:3306/spark", s"($sql) T", properties)
    person.show()
  }
}
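Note that passing s"($sql) T" as the table argument uses the JDBC source's subquery-as-table form, so the row_number() window function is evaluated inside MySQL itself rather than by Spark. Depending on the driver version you may also need properties.setProperty("driver", "com.mysql.cj.jdbc.Driver") and the Connector/J jar on the classpath.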
