Aggregation
groupBy
The groupBy operator groups a Dataset by the given columns and returns a RelationalGroupedDataset, through which aggregations can then be applied to each group.
// 1. Prepare the data
import org.apache.spark.sql.{RelationalGroupedDataset, SparkSession}
import org.apache.spark.sql.types._

val spark = SparkSession.builder()
  .master("local[6]")
  .appName("aggregation")
  .getOrCreate()

import spark.implicits._

val schema = StructType(
  List(
    StructField("id", IntegerType),
    StructField("year", IntegerType),
    StructField("month", IntegerType),
    StructField("day", IntegerType),
    StructField("hour", IntegerType),
    StructField("season", IntegerType),
    StructField("pm", DoubleType)
  )
)

val pmDF = spark.read
  .schema(schema)
  .option("header", value = true)
  .csv("dataset/pm_without_null.csv")

// 2. Aggregate using functions from org.apache.spark.sql.functions
import org.apache.spark.sql.functions._

val groupedDF: RelationalGroupedDataset = pmDF.groupBy('year)
groupedDF.agg(avg('pm) as "pm_avg")
  .orderBy('pm_avg.desc)
  .show()
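Since agg accepts any number of aggregate expressions, several statistics per group can be computed in a single pass. A minimal sketch reusing groupedDF from above (the column aliases are illustrative):

// One pass over each group, three aggregates at once
groupedDF.agg(avg('pm) as "pm_avg", max('pm) as "pm_max", count('pm) as "pm_count")
  .orderBy('pm_avg.desc)
  .show()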
// 3. Besides functions, the RelationalGroupedDataset API can be used directly for aggregation
groupedDF.avg("pm")
  .select($"avg(pm)" as "pm_avg")
  .orderBy('pm_avg)
  .show()

groupedDF.max("pm")
  .select($"max(pm)" as "pm_max")
  .orderBy('pm_max)
  .show()
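RelationalGroupedDataset also has a string-based agg overload that takes (column name, aggregate function name) pairs; a minimal sketch, again reusing groupedDF from above:

// Each pair is (column to aggregate, name of the aggregate function);
// the result columns are named like "avg(pm)", as in the selects above
groupedDF.agg("pm" -> "avg", "pm" -> "max")
  .select($"avg(pm)" as "pm_avg", $"max(pm)" as "pm_max")
  .orderBy('pm_avg.desc)
  .show()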
Multi-dimensional aggregation
We often need to aggregate data along several dimensions at once, producing subtotals and grand totals in a single result. The usual approach is as follows.
// 1. Prepare the data
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

val spark = SparkSession.builder()
  .master("local[6]")
  .appName("aggregation")
  .getOrCreate()

import spark.implicits._

val schemaFinal = StructType(
  List(
    StructField("source", StringType),
    StructField("year", IntegerType),
    StructField("month", IntegerType),
    StructField("day", IntegerType),
    StructField("hour", IntegerType),
    StructField("season", IntegerType),
    StructField("pm", DoubleType)
  )
)

val pmFinal = spark.read
  .schema(schemaFinal)
  .option("header", value = true)
  .csv("dataset/pm_final.csv")

// 2. Aggregate along multiple dimensions
import org.apache.spark.sql.functions._

// Subtotals per (source, year)
val groupPostAndYear = pmFinal.groupBy('source, 'year)
  .agg(sum("pm") as "pm")

// Totals per source, padded with a null year column so the schemas line up for union
val groupPost = pmFinal.groupBy('source)
  .agg(sum("pm") as "pm")
  .select('source, lit(null) as "year", 'pm)

groupPostAndYear.union(groupPost)
  .sort('source, 'year.asc_nulls_last, 'pm)
  .show()
// Result:
/*
+-------+----+------------------+
| source|year| pm|
+-------+----+------------------+
| dongsi|2013| 93.2090724784592|
| dongsi|2014| 87.08640822045773|
| dongsi|2015| 87.4922056770591|
| dongsi|null| 89.15443876736389|
|us_post|2010|104.04572982326042|
|us_post|2011| 99.0932403834184|
|us_post|2012| 90.53876763535511|
|us_post|2013|101.71110855035722|
|us_post|2014| 97.73409537004964|
|us_post|2015| 82.78472946356158|
|us_post|null| 95.90424117331851|
+-------+----+------------------+
*/
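Note that union matches columns by position, which is why groupPost pads the null year column in exactly the right place. Since Spark 2.3, unionByName can match columns by name instead; a minimal sketch under that assumption, reusing the frames above:

// unionByName aligns columns by name, so the padded column's position no longer matters
val groupPostByName = pmFinal.groupBy('source)
  .agg(sum("pm") as "pm")
  .withColumn("year", lit(null)) // appended at the end; alignment handled by name
groupPostAndYear.unionByName(groupPostByName).show()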
As you can see, getting both subtotals and grand totals from one dataset takes several operators. Can this be simplified? See below.
rollup
The rollup operator is an extension of groupBy: it performs a rolling groupBy over the given columns, grouping N + 1 times for N columns, where the last grouping aggregates over the entire dataset.
def rollup(): Unit = {
  // assumes a SparkSession named spark in the enclosing scope, as created above
  import org.apache.spark.sql.functions._
  import spark.implicits._

  val df = Seq(
    ("Beijing", 2016, 100),
    ("Beijing", 2017, 200),
    ("Shanghai", 2015, 50),
    ("Shanghai", 2016, 150),
    ("Guangzhou", 2017, 50)
  ).toDF("city", "year", "amount")

  // Rolling grouping over two columns A = city and B = year: (A, B), (A), then the whole dataset
  df.rollup('city, 'year)
    .agg(sum('amount) as "amount")
    // descending cities with nulls (totals) last, matching the result shown below
    .sort('city.desc_nulls_last, 'year.asc_nulls_last)
    .show()
}
/**
 * Result:
 * +---------+----+------+
 * |     city|year|amount|
 * +---------+----+------+
 * | Shanghai|2015|    50| <-- subtotal for Shanghai in 2015
 * | Shanghai|2016|   150|
 * | Shanghai|null|   200| <-- total for Shanghai
 * |Guangzhou|2017|    50|
 * |Guangzhou|null|    50|
 * |  Beijing|2016|   100|
 * |  Beijing|2017|   200|
 * |  Beijing|null|   300| <-- total for Beijing
 * |     null|null|   550| <-- grand total for the whole dataset
 * +---------+----+------+
 */
How can the same result be achieved with plain groupBy?
// Subtotals per (city, year)
val cityAndYear = df
  .groupBy("city", "year")
  .agg(sum("amount") as "amount")

// Totals per city, padded with a null year column
val city = df
  .groupBy("city")
  .agg(sum("amount") as "amount")
  .select($"city", lit(null) as "year", $"amount")

// Grand total over the whole dataset
val all = df
  .groupBy()
  .agg(sum("amount") as "amount")
  .select(lit(null) as "city", lit(null) as "year", $"amount")

cityAndYear
  .union(city)
  .union(all)
  .sort($"city".desc_nulls_last, $"year".asc_nulls_last)
  .show()
/**
 * Result (identical to the rollup output above):
 * +---------+----+------+
 * |     city|year|amount|
 * +---------+----+------+
 * | Shanghai|2015|    50|
 * | Shanghai|2016|   150|
 * | Shanghai|null|   200|
 * |Guangzhou|2017|    50|
 * |Guangzhou|null|    50|
 * |  Beijing|2016|   100|
 * |  Beijing|2017|   200|
 * |  Beijing|null|   300|
 * |     null|null|   550|
 * +---------+----+------+
 */
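For comparison, the same rolling grouping can also be written in Spark SQL using WITH ROLLUP; a sketch, assuming df is registered under the illustrative view name sales:

// Register the DataFrame as a temporary view (the view name is illustrative)
df.createOrReplaceTempView("sales")
spark.sql(
  """SELECT city, year, sum(amount) AS amount
    |FROM sales
    |GROUP BY city, year WITH ROLLUP""".stripMargin)
  .show()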