自定义UTF函数 弱类型
package SparkSql
import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
object Session_UDAF_1 {
def main(args: Array[String]): Unit = {
Logger.getLogger("org").setLevel(Level.ERROR)
val sc: SparkConf = new SparkConf().setMaster("local[*]").setAppName("DataFrame")
val spark: SparkSession = SparkSession.builder().config(sc).getOrCreate()
//TODO 导入隐式转换 这里的spark不是包的名字,而是SparkSession对象的名字
// (这里并没有用到RDD,DataFrame,DataSet类型转换,但一般建议出现SparkSession对象时,都要加上
import spark.implicits._
val frame: DataFrame = spark.read.
format("csv").
option("header", "true").option("inferSchema", true.toString).
load("F:\\MySQL\\python招聘.csv")
//DataFrame(表结构)对象创建一个临时视图
frame.createTempView("python")
//查看frame数据集字段名RecruitPostName,RecruitPostId前5行
spark.sql("select RecruitPostName,RecruitPostId from python limit 5").show()
//创建 UTF函数对象
val function = new MyAgeFunction()
//绑定临时函数
spark.udf.register("avg_RecruitPostId", function)
//使用自定义UTF函数求字段RecruitPostId的平均值
spark.sql("select avg_RecruitPostId(RecruitPostId) from python").show()
spark.stop()
}
}
//声明用户自定义聚合函数(弱类型)
// 1) 继承UserDefinedAggregateFunction
// 2)实现方法
class MyAgeFunction extends UserDefinedAggregateFunction {
//函数的输入结构
override def inputSchema: StructType = {
new StructType().add("Name", LongType)
}
// 计算时函数的的数据结构
override def bufferSchema: StructType = {
new StructType().add("sum", LongType).add("count", LongType)
}
//函数返回的数据类型
override def dataType: DataType = DoubleType
//函数是否稳定
override def deterministic: Boolean = true
//计算之间函数 缓冲区的初始化 给sum,count一个初始值
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0L
buffer(1) = 0L
}
//根据查询结果更新缓冲区数据
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer(0) = buffer.getLong(0) + input.getLong(0) //单节点sum求和
buffer(1) = buffer.getLong(1) + 1 // 单节点count计数
}
//将多个节点 合并缓冲区数据(其实类似分区,分区内计算完毕,分区间聚合分区内的结果)
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) //多节点sum求和
buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1) //多节点count计数
}
// 将最终计算结果返回
override def evaluate(buffer: Row): Any = {
buffer.getLong(0).toDouble / buffer.getLong(1)
}
}
注意,如果您的文件内容表头(字段名)有特殊字符(空格之类,我使用的是spark2.2.1版本,有出错请务必去除特殊字符,
不过好像2.3版本之后有改进,不知道是不是真的,反正我这块表头有特殊字符,就报错)
自定义UTAF函数 强类型
package SparkSql
import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Encoders, Row, SparkSession, TypedColumn}
object Session_UDAF_class {
def main(args: Array[String]): Unit = {
Logger.getLogger("org").setLevel(Level.ERROR)
val sc: SparkConf = new SparkConf().setMaster("local[*]").setAppName("DataSet")
val spark: SparkSession = SparkSession.builder().config(sc).getOrCreate() //导入隐式转换 这里的spark不是包的名字,而是SparkSession对象的名字
val frame: DataFrame = spark.read.
format("csv").
option("header", "true").
option("inferSchema", true.toString).
load("E:\\谷歌文件\\archive\\restaurant-1-orders.csv")
//TODO 导入隐式转换 这里的spark不是包的名字,而是SparkSession对象的名字
import spark.implicits._
//Order_ID Order_Date Item_Name Quantity Product_Price Total_products
frame.createTempView("test") //创建一个临时视图
spark.sql("select Order_ID,Product_Price from test limit 5").show()
//TODO DataFrame类型转RDD算子 注意RDD算子类型,DataFrame类型转RDD类 类型都将是ROW(行)
val rdd: RDD[Row] = frame.rdd
val resultRDD: RDD[Array[Any]] = rdd.map {
case x => Array(x.getInt(0), x.getString(1), x.getString(2),
x.getInt(3), x.getDouble(4), x.getInt(5))
}
//创建聚合函数对象
val clazz = new MyFunctionClass()
//将聚合函数转换为查询列
val value: TypedColumn[OrderBean, Double] = clazz.toColumn.name("avg_Product_Price")
// 将Dataframe类型转换为DataSet类型
val OrderDS: Dataset[OrderBean] = frame.as[OrderBean]
OrderDS.select(value).show()
//释放资源
spark.stop()
}
}
case class OrderBean(Order_ID: Long, Order_Date: String, Item_Name: String,
Quantity: Long, Product_Price: Double, Total_products: Long)
case class AvgBean(var sum: Double, var count: Int)
//声明用户自定义聚合函数(强类型)
// 1) 继承
// 2)实现方法
class MyFunctionClass extends Aggregator[OrderBean, AvgBean, Double] {
//初始化
override def zero: AvgBean = {
AvgBean(0.0, 0)
}
//聚合函数,与UserDefinedAggregateFunction的函数update原理相同,聚合函数更好理解而已
override def reduce(b: AvgBean, a: OrderBean): AvgBean = {
b.sum = b.sum + a.Product_Price //计算Product_Price这一列的总和(单节点)
b.count = b.count + 1
b
}
//与UserDefinedAggregateFunction的merge实现也是 合并多个节点数据
override def merge(b1: AvgBean, b2: AvgBean): AvgBean = {
b2.sum = b1.sum + b2.sum
b2.count = b1.count + b2.count
b2
}
//完成计算
override def finish(reduction: AvgBean): Double = reduction.sum / reduction.count
//以下是固定写法
override def bufferEncoder: Encoder[AvgBean] = Encoders.product //用户自定义类型写法
override def outputEncoder: Encoder[Double] = Encoders.scalaDouble //scala中自带的类型写法
}
总结: RDD算子难用,将其转换为DataFrame类型(二维表结构),Spark提供sql语句来很轻松的使用DataFrame,相对来说简单许多,至于DataSet是将DataFrame类型中的字段名封装成类(我们一般是使用样例类),类就有属性(字段名),毕竟在面向对象程序语言中,我们还是更希望通过对象来访问数据。