Spark SQL 中 RDD、DataFrame、Dataset 的创建与相互转换
package spark_sql
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
// Row type for the JSON input; field names/types must match the JSON columns.
case class Emp(name: String, age: Long)

// Demonstrates creating an RDD, DataFrame and Dataset and converting between them.
object spark_sql_json {
  // Explicit main instead of the App trait (App's delayed initialization is an
  // entry-point anti-pattern and is discouraged for Spark drivers).
  def main(args: Array[String]): Unit = {
    val ss: SparkSession = SparkSession.builder().master("local[*]").appName("sql").getOrCreate()
    import ss.implicits._

    // read.json already returns a DataFrame; the original extra .toDF() was a no-op.
    val df: DataFrame = ss.read.json("datas/a.json")
    println("df")
    df.show()

    // DataFrame -> strongly-typed Dataset[Emp].
    val ds: Dataset[Emp] = df.as[Emp]
    println("ds")
    ds.show()

    // DataFrame -> RDD[Row]; now actually printed instead of being an unused local.
    val df_rdd: RDD[Row] = df.rdd
    println("df.rdd")
    df_rdd.collect().foreach(println)

    // Dataset -> RDD[Emp].
    val rdd: RDD[Emp] = ds.rdd
    println("ds.rdd")
    rdd.collect().foreach(println)

    // RDD -> Dataset; the original discarded this result, so the "rdd.toDS"
    // section printed nothing — show it so the conversion is visible.
    println("rdd.toDS")
    rdd.toDS().show()

    ss.stop()
  }
}
自定义 UDAF 函数
弱类型:继承 UserDefinedAggregateFunction(注意:该 API 自 Spark 3.0 起已被标记为过时)
package sparkSql
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
// NOTE(review): `user` is not referenced anywhere in this snippet — presumably used
// elsewhere; its lowercase name also breaks the UpperCamelCase class convention.
case class user(name: String, age: Long)

// Registers the weakly-typed average UDAF and invokes it from a SQL aggregation.
object sparkSqlUdf {
  def main(args: Array[String]): Unit = {
    val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSql")
    val session: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()

    // Load the JSON data and expose it to SQL under the name "user".
    val users: DataFrame = session.read.json("datas/sql")
    users.createOrReplaceTempView("user")

    // Make the aggregate callable from SQL as udaf(...).
    session.udf.register("udaf", new myUDAF())

    session.sql("select name , udaf(age)+1 as newAge from user group by name").show()

    session.stop()
  }
}
// Weakly-typed average UDAF: accumulates (total, count) and returns total / count.
// NOTE(review): UserDefinedAggregateFunction is deprecated since Spark 3.0;
// prefer an Aggregator registered via functions.udaf (see the typed version below).
class myUDAF extends UserDefinedAggregateFunction {

  // One input column; Spark casts it to Double per this declared schema.
  override def inputSchema: StructType =
    new StructType().add(StructField("age", DoubleType))

  // Aggregation state: running total and element count.
  override def bufferSchema: StructType =
    new StructType()
      .add(StructField("total", DoubleType))
      .add(StructField("count", DoubleType))

  // The final result (the mean) is a Double.
  override def dataType: DataType = DoubleType

  // Same inputs always yield the same output.
  override def deterministic: Boolean = true

  // Zero state: nothing summed, nothing counted.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0
    buffer(1) = 0.0
  }

  // Fold one input row into the buffer.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getDouble(0) + input.getDouble(0)
    buffer(1) = buffer.getDouble(1) + 1.0
  }

  // Combine two partial buffers into the first.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0))
    buffer1(1) = buffer1.getDouble(1) + buffer2.getDouble(1)
  }

  // Mean of the group; Double division yields NaN for an empty group.
  override def evaluate(buffer: Row): Any =
    buffer.getDouble(0) / buffer.getDouble(1)
}
强类型:继承 Aggregator(org.apache.spark.sql.expressions.Aggregator),配合 functions.udaf 注册(Spark 3.0+ 的推荐做法)
package sparkSql
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.{DataFrame, Encoder, Encoders, Row, SparkSession, functions}
// Registers the strongly-typed Aggregator as a SQL-callable UDAF and runs it.
object sparkSqlUdf_new {
  def main(args: Array[String]): Unit = {
    val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSql")
    val session: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()

    // Load the JSON data and expose it to SQL under the name "user".
    val users: DataFrame = session.read.json("datas/sql")
    users.createOrReplaceTempView("user")

    // functions.udaf wraps the type-safe Aggregator so SQL can call it by name.
    session.udf.register("udaf", functions.udaf(new UDAF()))

    session.sql("select name , udaf(age)+1 as newAge from user group by name").show()

    session.stop()
  }
}
// Mutable accumulator for the average: running sum and element count.
case class Buff(var sum: Long, var cnt: Long)

// Strongly-typed average aggregator: input = Long, buffer = Buff, output = Long.
class UDAF extends Aggregator[Long, Buff, Long] {

  // Empty accumulator.
  override def zero: Buff = Buff(0L, 0L)

  // Fold one value into the accumulator (mutates it in place and returns it).
  override def reduce(buff: Buff, a: Long): Buff = {
    buff.sum += a
    buff.cnt += 1L
    buff
  }

  // Combine two partial accumulators into the first.
  override def merge(buff1: Buff, buff2: Buff): Buff = {
    buff1.sum += buff2.sum
    buff1.cnt += buff2.cnt
    buff1
  }

  // Integer mean; Long division throws ArithmeticException when cnt == 0.
  override def finish(reduction: Buff): Long = reduction.sum / reduction.cnt

  // Encoders telling Spark how to serialize the buffer and the result.
  override def bufferEncoder: Encoder[Buff] = Encoders.product
  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}