Overview:
During development, the built-in functions sometimes cannot meet our needs, so we define our own. Spark supports three kinds of user-defined functions: UDF, UDAF, and UDTF.
UDF: one-to-one (one input row produces one output value)
UDAF: many-to-one (many input rows aggregate into one output value)
UDTF: one-to-many (one input row produces many output rows)
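The three shapes map onto familiar collection operations. As a rough plain-Scala analogy (ordinary collections, not Spark code):

```scala
// Plain Scala collections, only to illustrate the three shapes:
val rows = Seq("a,b", "c")

// UDF: one-to-one, like map -- one output value per input row
val counts = rows.map(_.split(",").length)   // Seq(2, 1)

// UDAF: many-to-one, like a fold -- many values aggregate to one
val total = counts.sum                       // 3

// UDTF: one-to-many, like flatMap -- one row expands into many rows
val items = rows.flatMap(_.split(","))       // Seq("a", "b", "c")
```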
UDF example:
Contents of hobbies.txt:
alice jogging,Coding,cooking
lina travel,dance
Goal: count the number of hobbies for each person.
Code:
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions

// case class for the two columns in hobbies.txt
case class Hobbies(name: String, hobbies: String)

val conf: SparkConf = new SparkConf().setAppName("innserdemo").setMaster("local[*]")
val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
val sc: SparkContext = spark.sparkContext
import spark.implicits._
// read the input file
val hobbyDF: DataFrame = sc.textFile("in/hobbies.txt")
  .map(x => x.split(" "))
  .map(x => Hobbies(x(0), x(1))).toDF()
hobbyDF.createOrReplaceTempView("hobby")
// style 1: register the UDF by name for use in SQL
spark.udf.register("hobby_num", (x: String) => x.split(",").length)
// style 2: wrap the function for use with the DataFrame API
val hobby_num: UserDefinedFunction = functions.udf((hobbies: String) => {
  hobbies.split(",").length
})
val newhobbyDF: DataFrame = hobbyDF.withColumn("hobbynum", hobby_num($"hobbies"))
newhobbyDF.printSchema()
newhobbyDF.show(false)
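Since the UDF body is plain Scala, you can sanity-check it against the two sample rows without starting Spark:

```scala
// The same function the UDF wraps:
val hobbyNum: String => Int = _.split(",").length

// Expected counts for the sample rows in hobbies.txt:
assert(hobbyNum("jogging,Coding,cooking") == 3) // alice
assert(hobbyNum("travel,dance") == 2)           // lina
```

The registered version can equally be called from SQL against the temp view, e.g. spark.sql("select name, hobby_num(hobbies) from hobby").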
Output:
UDAF example:
A custom UDAF extends UserDefinedAggregateFunction.
Goal: compute the average age grouped by gender.
Code:
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{DataFrame, SparkSession}

// case class for the student records
case class Student(id: Int, name: String, gender: String, age: Int)

val conf: SparkConf = new SparkConf().setAppName("innserdemo").setMaster("local[*]")
val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
val sc: SparkContext = spark.sparkContext
val students: Seq[Student] = Seq(
  Student(1, "zhangsan", "F", 22),
  Student(2, "lisi", "M", 38),
  Student(3, "wangwu", "M", 13),
  Student(4, "zhaoliu", "F", 17),
  Student(5, "songba", "M", 32),
  Student(6, "sunjiu", "M", 16),
  Student(7, "qianshiyi", "F", 17),
  Student(8, "yinshier", "F", 15),
  Student(9, "fangshisan", "M", 12),
  Student(10, "yeshisan", "F", 11),
  Student(11, "ruishiyi", "F", 26),
  Student(12, "chenshier", "M", 28)
)
val frame: DataFrame = spark.createDataFrame(students)
frame.printSchema()
// register the UDAF so it can be called from SQL
spark.udf.register("myAvg", new MyAgeAvgFunction)
frame.createOrReplaceTempView("students")
val resultDF: DataFrame = spark.sql(
  "select gender, myAvg(age) from students group by gender"
)
resultDF.printSchema()
resultDF.show(false)
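To cross-check what myAvg should return for the twelve students above, the same grouping can be replayed with plain Scala collections (a sanity check only, not part of the Spark job):

```scala
// Ages from the students Seq above, grouped by gender:
val ages = Map(
  "F" -> Seq(22, 17, 17, 15, 11, 26), // zhangsan, zhaoliu, qianshiyi, yinshier, yeshisan, ruishiyi
  "M" -> Seq(38, 13, 32, 16, 12, 28)  // lisi, wangwu, songba, sunjiu, fangshisan, chenshier
)
val avgByGender = ages.map { case (g, xs) => g -> xs.sum.toDouble / xs.size }
// F -> 108 / 6 = 18.0, M -> 139 / 6 = 23.1666...
```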
The custom UDAF MyAgeAvgFunction:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class MyAgeAvgFunction extends UserDefinedAggregateFunction {
  // schema of the input data passed to the aggregate function
  override def inputSchema: StructType = {
    // equivalent: new StructType().add("age", LongType)
    StructType(StructField("age", LongType) :: Nil)
  }
  // schema of the aggregation buffer, e.g. ageSum(1000), ageNum(200)
  // sum records the running total of all ages: 43 + 52 + 61 + 78 = 234 => sum
  // count records how many values have been added: 1 + 1 + 1 + 1 = 4 => count
  override def bufferSchema: StructType = {
    // equivalent: new StructType().add("sum", LongType).add("count", LongType)
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
  }
  // return type of the function: sum / count
  override def dataType: DataType = DoubleType
  // whether the function always returns the same result for the same input
  override def deterministic: Boolean = true
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L // running total of all ages seen so far
    buffer(1) = 0L // number of ages seen so far
  }
  // called for each new input row:
  // add the value in the Row (e.g. Row(63)) to buffer(0),
  // and increment the count in buffer(1)
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }
  // merge the buffers of the individual partitions,
  // e.g. p1(321,6) p2(128,2) p3(219,3)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // total of the ages
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    // total count
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  // compute the final result
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0) / buffer.getLong(1).toDouble
  }
}
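The buffer lifecycle (initialize, then update within each partition, then merge across partitions, then evaluate) can be traced in plain Scala, using the example partition buffers p1(321,6), p2(128,2), p3(219,3) from the merge comment:

```scala
// Model the (sum, count) buffer as a tuple and replay the lifecycle:
def init: (Long, Long) = (0L, 0L)
def update(b: (Long, Long), age: Long): (Long, Long) = (b._1 + age, b._2 + 1)
def merge(a: (Long, Long), b: (Long, Long)): (Long, Long) = (a._1 + b._1, a._2 + b._2)
def evaluate(b: (Long, Long)): Double = b._1 / b._2.toDouble

// update adds one value and bumps the count:
assert(update(init, 43L) == (43L, 1L))

// Merging the three partition buffers:
val merged = Seq((321L, 6L), (128L, 2L), (219L, 3L)).foldLeft(init)(merge)
// merged == (668, 11); evaluate(merged) == 668 / 11.0
```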
Output:
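A note on API versions: UserDefinedAggregateFunction is deprecated since Spark 3.0 in favor of the typed Aggregator API. Assuming Spark 3.x, the same average could be sketched roughly as follows (AvgBuffer and MyAgeAvg are placeholder names, not from the original code):

```scala
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.functions

// (sum, count) buffer as a case class
case class AvgBuffer(var sum: Long, var count: Long)

object MyAgeAvg extends Aggregator[Long, AvgBuffer, Double] {
  def zero: AvgBuffer = AvgBuffer(0L, 0L)
  def reduce(b: AvgBuffer, age: Long): AvgBuffer = { b.sum += age; b.count += 1; b }
  def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = { b1.sum += b2.sum; b1.count += b2.count; b1 }
  def finish(b: AvgBuffer): Double = b.sum.toDouble / b.count
  def bufferEncoder: Encoder[AvgBuffer] = Encoders.product
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// register for SQL use (Spark 3.x):
// spark.udf.register("myAvg", functions.udaf(MyAgeAvg, Encoders.scalaLong))
```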
UDTF example:
A custom UDTF extends Hive's GenericUDTF.
Contents of UDTF.txt:
01//zs//Hadoop scala spark hive hbase
02//ls//Hadoop scala kafka hive hbase Oozie
03//ww//Hadoop scala spark hive sqoop
Goal: list the course information for one particular student.
Code:
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}

val conf: SparkConf = new SparkConf().setAppName("UDTFDemo").setMaster("local[*]")
val spark: SparkSession = SparkSession.builder()
  .config(conf)
  .config("hive.metastore.uris", "thrift://192.168.91.135:9083")
  .enableHiveSupport()
  .getOrCreate()
val sc: SparkContext = spark.sparkContext
import spark.implicits._
val rdd: RDD[String] = sc.textFile("in/UDTF.txt")
val rdd2: RDD[(String, String, String)] = rdd.map(x => {
  x.split("//")
}).filter(x => x(1).equals("ls"))
  .map(x => (x(0), x(1), x(2)))
val frame: DataFrame = rdd2.toDF("id", "name", "class")
frame.createOrReplaceTempView("udtftable")
// registering a Hive UDTF requires Hive support (enableHiveSupport above)
spark.sql("create temporary function Myudtf as 'day12_13.MyUDTF'")
spark.sql("select Myudtf(class) from udtftable").show(false)
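The RDD pipeline above is ordinary Scala string handling; replayed on the three sample lines, it keeps only the "ls" row:

```scala
// The three lines of UDTF.txt:
val lines = Seq(
  "01//zs//Hadoop scala spark hive hbase",
  "02//ls//Hadoop scala kafka hive hbase Oozie",
  "03//ww//Hadoop scala spark hive sqoop"
)
// Same split / filter / map chain as the RDD code:
val lsRows = lines.map(_.split("//"))
  .filter(x => x(1).equals("ls"))
  .map(x => (x(0), x(1), x(2)))
// Seq(("02", "ls", "Hadoop scala kafka hive hbase Oozie"))
```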
The custom UDTF MyUDTF:
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, StructObjectInspector}
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory

class MyUDTF extends GenericUDTF {
  /*
  input:  Hadoop scala kafka hive hbase Oozie
  output: one row per word, in a single string column named "type":
    Hadoop
    scala
    kafka
    hive
    hbase
    Oozie
  */
  override def process(objects: Array[AnyRef]): Unit = {
    val strings: Array[String] = objects(0).toString.split(" ")
    for (str <- strings) {
      // forward() emits one output row per word
      val temp = new Array[String](1)
      temp(0) = str
      forward(temp)
    }
  }
  override def close(): Unit = {
  }
  override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
    val fieldName = new java.util.ArrayList[String]()
    val fieldOIS = new java.util.ArrayList[ObjectInspector]()
    // define the name and type of the output column
    fieldName.add("type")
    fieldOIS.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
    ObjectInspectorFactory.getStandardStructObjectInspector(fieldName, fieldOIS)
  }
}
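process() just splits the input string on spaces and forwards one single-column row per word; that expansion can be traced without Hive. (As an aside, Spark's built-in explode(split($"class", " ")) would produce the same one-to-many result without a custom UDTF.)

```scala
// What process() does for the "ls" row, minus the forward() calls:
val input = "Hadoop scala kafka hive hbase Oozie"
val rows: Array[Array[String]] = input.split(" ").map(word => Array(word))

assert(rows.length == 6)
assert(rows.head(0) == "Hadoop")
assert(rows.last(0) == "Oozie")
```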
Output: