UDF:计算单词的长度
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}
object udfDemo {
def main(args: Array[String]): Unit = {
val session = SparkSession.builder().appName("udfDemo").master("local[2]").getOrCreate()
val sc = session.sparkContext
//构造模拟数据
val names=Array("Leo","Marry","Jack","Tom")
val namesRdd = sc.parallelize(names,4)
val nameRowRdd = namesRdd.map(name=>Row(name))
val structType = StructType(Array(StructField("name",StringType,true)))
//创建dataframe,创建是由SparkSession创建的
val namesDF = session.createDataFrame(nameRowRdd,structType)
//注册临时表
namesDF.createTempView("names")
//定义和注册自定义函数
//定义函数:自己写匿名函数
//注册函数:udf.register,注册函数使用的是SparkSession
session.udf.register("strLen",(str:String)=>str.length)
//接下来只用的自定义函数
session.sql("select name,strLen(name) from names").collect().foreach(println)
}
}
udaf:统计单词出现的次数
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}
object udfDemo1 {
def main(args: Array[String]): Unit = {
val sp = SparkSession.builder().appName("udfDemo").master("local[*]").getOrCreate()
val sparkContext = sp.sparkContext
val rdd0 = sparkContext.parallelize(Array("tom","jerry","mary","mary","tom")).map(name=>Row(name))
val structType = StructType(Array(StructField("name",StringType,true)))//类型和是否为空
val namedf = sp.createDataFrame(rdd0,structType)
//注册表
namedf.createTempView("v_name")
//注册自定义的函数
sp.udf.register("myagg",new UDFA)
//使用
val res1= sp.sql("select name,myagg(name) from v_name group by name").collect().foreach(println)
}
}
class UDFA extends UserDefinedAggregateFunction{
//指的是,输入数据类型
override def inputSchema: StructType = {
StructType(Array(StructField("str",StringType,true)))
}
//指的是,中间进行聚合时,所处理数据的类型
override def bufferSchema: StructType = {
StructType(Array(StructField("count",IntegerType,true)))
}
//最后聚合结果的返回类型--》函数返回值的类型
override def dataType: DataType = {
IntegerType
}
//数据的一致性,一般为true
override def deterministic: Boolean = true
//聚合函数的初始化值,局部聚合使用,为每一个分组的数据执行初始化操作
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0)=0
}
// 在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
// 本地的聚合操作,相当于Hadoop MapReduce模型中的Combiner
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer(0)=buffer.getAs[Int](0)+1
}
// 由于Spark是分布式的,所以一个分组的数据,可能会在不同的节点上进行局部聚合,就是update
// 但是,最后一个分组,在各个节点上的聚合值,要进行merge,也就是合并
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0)=buffer1.getAs[Int](0)+buffer2.getAs[Int](0)
}
// 最后返回一个最终的聚合值
override def evaluate(buffer: Row): Any = {
buffer.getAs[Int](0)
}
}