spark中使用UDF函数
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}
object SparkUDFDemo {
  /** One record per line of in/hobbies.txt: "<name> <comma-separated hobbies>". */
  case class Hobbies(name: String, hobbies: String)

  /**
   * Demonstrates registering and using a scalar UDF in Spark SQL.
   *
   * Reads "in/hobbies.txt", builds a DataFrame of (name, hobbies),
   * registers a UDF that counts comma-separated hobby tokens, and shows
   * the result of a SQL query that applies it.
   */
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .master("local")
      .appName("sparkudfdemo")
      .getOrCreate()
    val sc: SparkContext = spark.sparkContext
    import spark.implicits._ // required for rdd.toDF

    val rdd: RDD[String] = sc.textFile("in/hobbies.txt")
    // Each input line: name and hobby list separated by a single space.
    val df: DataFrame = rdd.map(_.split(" ")).map(x => Hobbies(x(0), x(1))).toDF()

    // createOrReplaceTempView replaces registerTempTable, which has been
    // deprecated since Spark 2.0; it is also idempotent across re-runs.
    df.createOrReplaceTempView("hobbies")

    // UDF: number of hobbies = number of comma-separated tokens.
    spark.udf.register("hoby_num", (v: String) => v.split(",").length)
    spark.sql("select name,hobbies,hoby_num(hobbies) from hobbies").show()

    // Release the session's resources instead of leaking them on exit.
    spark.stop()
  }
}
spark中使用UDAF函数
import org.apache.spark.SparkContext
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
object SparkUDAFDemo {
  /**
   * Demonstrates a custom UDAF: registers MyAgeAvgFunction under the SQL
   * name "myAvg" and computes the average age per sex over in/user.json.
   *
   * Fixes vs. the original: drops the unused SparkContext and unused
   * implicits import, and uses createOrReplaceTempView (createTempView
   * throws AnalysisException when the view already exists in the session).
   */
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .master("local")
      .appName("sparkudaf")
      .getOrCreate()

    val df: DataFrame = spark.read.json("in/user.json")

    // Register the aggregate function implemented below.
    spark.udf.register("myAvg", new MyAgeAvgFunction)

    df.createOrReplaceTempView("userinfo")

    val resultDF: DataFrame = spark.sql("select sex, myAvg(age) from userinfo group by sex")
    resultDF.printSchema()
    resultDF.show()

    // Release the session's resources instead of leaking them on exit.
    spark.stop()
  }
}
/**
 * Aggregate function computing the average of a single Long "age" column.
 *
 * Buffer layout: (sum: Long, count: Long); the final result is
 * sum / count as a Double.
 */
class MyAgeAvgFunction extends UserDefinedAggregateFunction() {

  // Exactly one Long input column, named "age".
  override def inputSchema: StructType =
    StructType(StructField("age", LongType) :: Nil)

  // Intermediate state: running sum and number of rows seen.
  override def bufferSchema: StructType =
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)

  // The aggregate yields a Double (the average).
  override def dataType: DataType = DoubleType

  // The same input rows always produce the same result.
  override def deterministic: Boolean = true

  // Start every partial aggregation from (sum = 0, count = 0).
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0L)
    buffer.update(1, 0L)
  }

  // Fold one input row into the buffer: add its age, bump the count.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer.update(0, buffer.getLong(0) + input.getLong(0))
    buffer.update(1, buffer.getLong(1) + 1L)
  }

  // Combine two partial buffers by summing both components.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
  }

  // Final result: total / count as a floating-point average.
  override def evaluate(buffer: Row): Any =
    buffer.getLong(0).toDouble / buffer.getLong(1)
}
spark中使用UDTF函数
import java.util
import org.apache.hadoop.hive.ql.exec.UDFArgumentException
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, StructObjectInspector}
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}
object SparkUDTFDemo {
  /**
   * Demonstrates a Hive-based UDTF in Spark SQL.
   *
   * Reads "in/udtf.txt", keeps only rows whose second field equals "ls",
   * registers MyUDTF as a temporary Hive function, and uses it to explode
   * the space-separated "class" column into one row per token.
   */
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .master("local")
      .appName("sparkudtf")
      .enableHiveSupport() // required for CREATE TEMPORARY FUNCTION
      .getOrCreate()
    import spark.implicits._ // required for .toDF below
    val sc: SparkContext = spark.sparkContext

    val lines: RDD[String] = sc.textFile("in/udtf.txt")
    // NOTE(review): fields appear to be delimited by the literal "//" in the
    // data file — confirm against in/udtf.txt.
    val stuDF: DataFrame = lines
      .map(_.split("//"))
      .filter(fields => fields(1).equals("ls"))
      .map(fields => (fields(0), fields(1), fields(2)))
      .toDF("id", "name", "class")

    stuDF.printSchema()
    stuDF.show()
    stuDF.createOrReplaceTempView("student")

    // Register the Hive UDTF implemented below, then apply it in SQL.
    spark.sql("CREATE TEMPORARY FUNCTION MyUDTF AS 'nj.zb.kb09.sql.MyUDTF'")
    spark.sql("select MyUDTF(class) from student").show()
  }
}
/**
 * Hive generic UDTF: takes exactly one primitive (string) argument and
 * emits one output row per space-separated token, in a single string
 * column named "type".
 */
class MyUDTF extends GenericUDTF {

  override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
    // Exactly one argument is accepted.
    if (argOIs.length != 1) {
      throw new UDFArgumentException("有且只能有个参数传入")
    }
    // The argument must be a primitive-category column.
    if (argOIs(0).getCategory != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentException("参数类型不匹配")
    }
    // Declare the output schema: a single string column called "type".
    val names = new util.ArrayList[String]
    val inspectors = new util.ArrayList[ObjectInspector]
    names.add("type")
    inspectors.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
    ObjectInspectorFactory.getStandardStructObjectInspector(names, inspectors)
  }

  // Split the input value on spaces and forward each token as its own row.
  override def process(objects: Array[AnyRef]): Unit =
    objects(0).toString.split(" ").foreach(token => forward(Array(token)))

  // No resources to release.
  override def close(): Unit = {}
}