// ===== Section 1: Spark UDF demo =====
import org.apache.spark.sql.SparkSession
object SparkUDFDemo {
  // One input record: a person's name and a comma-separated hobby list.
  case class Hobbies(name: String, hobbies: String)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]")
      .appName("sparkudfdemo")
      .getOrCreate()
    import spark.implicits._
    val sc = spark.sparkContext

    // Each line of in/hobbies: "<name> <hobby1,hobby2,...>"
    val rdd = sc.textFile("in/hobbies")
    val df = rdd.map(_.split(" ")).map(x => Hobbies(x(0), x(1))).toDF()
    // df.show()

    // registerTempTable has been deprecated since Spark 2.0;
    // createOrReplaceTempView is the supported replacement.
    df.createOrReplaceTempView("hobbies")

    // UDF: number of comma-separated hobbies in the hobbies column.
    spark.udf.register("hoby_num", (v: String) => v.split(",").length)

    val frame = spark.sql("select name,hobbies,hoby_num(hobbies) as hobynum from hobbies")
    frame.show()

    // Release the local Spark resources before the JVM exits.
    spark.stop()
  }
}
// ===== Section 2: Spark UDAF demo =====
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
object SparkUDAFDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]")
      .appName("sparkudafdemo")
      .getOrCreate()
    import spark.implicits._
    val sc = spark.sparkContext

    // Load user records as JSON (expects "sex" and "age" fields — TODO confirm schema of in/user).
    val df = spark.read.json("in/user")
    // df.show()

    // Instantiate and register the custom average-age aggregate so SQL can call it.
    val ageAvg = new MyAgeAvgFunction
    spark.udf.register("myAvgAge", ageAvg)

    // Average age per sex, computed by the registered UDAF.
    df.createTempView("userinfo")
    val resultDF: DataFrame = spark.sql("select sex,myAvgAge(age) from userinfo group by sex")
    resultDF.show()
  }
}
class MyAgeAvgFunction extends UserDefinedAggregateFunction {

  // Input schema: a single LongType column ("age").
  override def inputSchema: StructType =
    StructType(StructField("age", LongType) :: Nil)

  // Aggregation buffer: running sum of ages and running row count.
  override def bufferSchema: StructType =
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)

  // Result type: the average age, as a Double.
  override def dataType: DataType = DoubleType

  // Deterministic: identical input always produces identical output.
  override def deterministic: Boolean = true

  // Start each buffer with sum = 0 and count = 0.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0L)
    buffer.update(1, 0L)
  }

  // Fold one input row into the buffer: add its age, bump the count.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer.update(0, buffer.getLong(0) + input.getLong(0))
    buffer.update(1, buffer.getLong(1) + 1L)
  }

  // Merge two partial buffers: sums add, counts add.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
  }

  // Final result: total age divided by row count.
  override def evaluate(buffer: Row): Any =
    buffer.getLong(0).toDouble / buffer.getLong(1)
}
// ===== Section 3: Spark + Hive UDTF demo =====
import org.apache.hadoop.hive.ql.exec.UDFArgumentException
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, PrimitiveObjectInspector, StructObjectInspector}
import java.util
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.spark.sql.SparkSession
object SparkUDTFDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]")
      .appName("sparkudtfdemo")
      .enableHiveSupport() // Hive support is required for CREATE TEMPORARY FUNCTION below.
      .getOrCreate()
    import spark.implicits._
    val sc = spark.sparkContext

    // Each line of in/udtf looks like "<id>//<name>//<space-separated class list>" — TODO confirm.
    val lines = sc.textFile("in/udtf")
    val stuDF = lines
      .map(_.split("//"))
      .filter(fields => fields(1).equals("ls")) // keep only records for user "ls"
      .map(fields => (fields(0), fields(1), fields(2)))
      .toDF("id", "name", "class")
    // stuDF.show()

    stuDF.createOrReplaceTempView("student")

    // Register the Hive generic UDTF by its (default-package) class name,
    // then use it to explode the "class" column into rows.
    spark.sql("CREATE TEMPORARY FUNCTION myUDTF AS 'myUDTF'")
    val resultDF = spark.sql("select myUDTF(class) from student")
    resultDF.show()
  }
}
class myUDTF extends GenericUDTF {

  // Validates that exactly one primitive argument was passed and declares the
  // output schema: a single string column named "type".
  override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
    if (argOIs.length != 1) {
      throw new UDFArgumentException("有且只能有一个参数传入")
    }
    if (argOIs(0).getCategory != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentException("参数类型不匹配")
    }
    val fieldNames = new util.ArrayList[String]
    val fieldOIs = new util.ArrayList[ObjectInspector]
    // Output column name and type: one java String column called "type".
    fieldNames.add("type")
    fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
  }

  // Input:  "Hadoop scala kafka hive hbase Oozie"  (one space-separated string)
  // Output: one row per token:
  //         "Hadoop", "scala", "kafka", "hive", "hbase", "Oozie"
  override def process(objects: Array[AnyRef]): Unit = {
    val tokens = objects(0).toString.split(" ")
    // NOTE: removed the original `println(strings)` — printing an Array uses
    // its identity toString (a hashcode), so it was useless debug output.
    for (token <- tokens) {
      // forward() emits one output row per token.
      forward(Array(token))
    }
  }

  // No resources to release.
  override def close(): Unit = {
  }
}