直接上代码。
UDF(用户自定义函数):一进一出 —— 每行一个输入值,产生一个输出值。
package sparksql_udf
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, SparkSession}
object sparkSqlUDF {
  def main(args: Array[String]): Unit = {
    // Build a local SparkSession and grab its SparkContext.
    val session: SparkSession = SparkSession
      .builder()
      .master("local[*]")
      .appName("zhiDingSchema")
      .getOrCreate()
    val context: SparkContext = session.sparkContext
    context.setLogLevel("WARN")

    // Load raw text lines. (Alternatively spark.read.textFile("D:\\data\\udf.txt")
    // would yield a Dataset[String] directly.)
    val lines: RDD[String] = context.textFile("D:\\大数据\\学期文档\\spark\\资料\\udf.txt")

    // Convert the RDD to a Dataset[String] and expose it as a temp view.
    import session.implicits._
    val lineDS: Dataset[String] = lines.toDS()
    lineDS.createOrReplaceTempView("udf")

    // Register a one-in/one-out UDF: upper-case the value and append " 123".
    session.udf.register("toUpper", (value: String) => value.toUpperCase + " 123")

    // Invoke the UDF from SQL.
    session.sql("select value,toUpper(value) from udf").show()

    // Shut down the context and session.
    context.stop()
    session.stop()
  }
}
UDAF(用户自定义聚合函数):多进一出 —— 多行输入聚合为一个输出值。
package sparksql_udf
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
object sparkSqlUDAF {
  def main(args: Array[String]): Unit = {
    // Build a local SparkSession and grab its SparkContext.
    val spark: SparkSession = SparkSession
      .builder()
      .master("local[*]")
      .appName("zhiDingSchema")
      .getOrCreate()
    val sc: SparkContext = spark.sparkContext
    sc.setLogLevel("WARN")

    // Load the JSON records and expose them as a temp view.
    val dfJson: DataFrame = spark.read.json("D:\\大数据\\学期文档\\spark\\资料\\udaf.json")
    dfJson.createOrReplaceTempView("UDAF")

    // Register the average-salary UDAF under the SQL name "SalaryAvg".
    spark.udf.register("SalaryAvg", new SalaryAvg)

    // Invoke the UDAF from SQL; compare with the built-in average if needed:
    // spark.sql("select avg(salary) from UDAF").show()
    spark.sql("select SalaryAvg(salary) from UDAF").show()

    // Shut down the context and session.
    sc.stop()
    spark.stop()
  }

  /**
   * Multi-in/one-out aggregate: computes the average of a Long `salary` column.
   *
   * NOTE(review): `UserDefinedAggregateFunction` is deprecated as of Spark 3.0
   * in favor of `org.apache.spark.sql.expressions.Aggregator` — worth migrating
   * when the project moves to Spark 3.x.
   */
  class SalaryAvg extends UserDefinedAggregateFunction {
    // One Long input column (the salary value).
    override def inputSchema: StructType =
      StructType(List(StructField("input", LongType)))

    // Aggregation buffer: running sum of salaries and count of rows seen.
    override def bufferSchema: StructType =
      StructType(List(StructField("sum", LongType), StructField("count", LongType)))

    // The final result is a Double (the average).
    override def dataType: DataType = DoubleType

    // Same input always produces the same output.
    override def deterministic: Boolean = true

    // Start each buffer at sum = 0, count = 0.
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = 0L // running sum
      buffer(1) = 0L // row count
    }

    // Fold one input row into the buffer (per-partition step).
    // Fix: skip NULL salaries — Row.getLong on a null cell would throw,
    // and SQL aggregates ignore NULLs.
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      if (!input.isNullAt(0)) {
        buffer(0) = buffer.getLong(0) + input.getLong(0)
        buffer(1) = buffer.getLong(1) + 1
      }
    }

    // Combine two partition buffers (global merge step).
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
      buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
    }

    // Produce the final average.
    // Fix: return SQL NULL (matching the built-in avg) instead of NaN
    // when no non-null rows were aggregated.
    override def evaluate(buffer: Row): Any = {
      val count = buffer.getLong(1)
      if (count == 0L) null
      else buffer.getLong(0).toDouble / count.toDouble
    }
  }
}