窗口函数的使用(1)
窗口是非常重要的统计工具,很多数据库都支持窗口函数。Spark从1.4开始支持窗口(window)函数。它主要有以下一些特点:
- 先对在一组数据行上进行操作,这组数据被称为Frame。
- 一个Frame对应当前处理的行
- 通过聚合/窗口函数为每行返回一个新值
- 可以使用SQL语法或DataFrame API。
准备工作
- 准备依赖库
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
- 准备数据
case class Salary(depName: String, empNo: Long, name: String,
salary: Long, hobby: Seq[String])
val empsalary = Seq(
Salary("sales", 1, "Alice", 5000, List("game", "ski")),
Salary("personnel", 2, "Olivia", 3900, List("game", "ski")),
Salary("sales", 3, "Ella", 4800, List("skate", "ski")),
Salary("sales", 4, "Ebba", 4800, List("game", "ski")),
Salary("personnel", 5, "Lilly", 3500, List("climb", "ski")),
Salary("develop", 7, "Astrid", 4200, List("game", "ski")),
Salary("develop", 8, "Saga", 6000, List("kajak", "ski")),
Salary("develop", 9, "Freja", 4500, List("game", "kajak")),
Salary("develop", 10, "Wilma", 5200, List("game", "ski")),
Salary("develop", 11, "Maja", 5200, List("game", "farming"))).toDS
empsalary.createTempView("empsalary")
empsalary.show()
数据输出如下:
+---------+-----+------+------+---------------+
| depName|empNo| name|salary| hobby|
+---------+-----+------+------+---------------+
| sales| 1| Alice| 5000| [game, ski]|
|personnel| 2|Olivia| 3900| [game, ski]|
| sales| 3| Ella| 4800| [skate, ski]|
| sales| 4| Ebba| 4800| [game, ski]|
|personnel| 5| Lilly| 3500| [climb, ski]|
| develop| 7|Astrid| 4200| [game, ski]|
| develop| 8| Saga| 6000| [kajak, ski]|
| develop| 9| Freja| 4500| [game, kajak]|
| develop| 10| Wilma| 5200| [game, ski]|
| develop| 11| Maja| 5200|[game, farming]|
+---------+-----+------+------+---------------+
注:例子代码来源于网络。
基本Frame操作
通过partitionBy
来获取基本Frame,然后基于获取到的Frame,来进行各种操作。
一个基本的Frame有以下的特征:
- 通过在一列或多列上调用函数
Window.partitionBy
来创建Frame; - 每一行都有与之对应的frame;
- 在同一个分区中,每一行的frame相同,但Ordered frame除外。
- Aggregate/Window函数能够使用在frame的每个行上,并得到一个值。
在示例数据中,我们来计算以下的值:
- 使用函数avg计算部门的平均工资
- 使用函数sum来计算部门的总工资
代码如下:
val overCategory = Window.partitionBy('depName)
// 基于窗口求平均值,和,最大值,最小值
val df = empsalary.withColumn("salaries", collect_list('salary) over overCategory).
withColumn("average_salary", (avg('salary) over overCategory).cast("int")).
withColumn("total_salary", sum('salary) over overCategory).
withColumn("min", min('salary) over overCategory).
withColumn("max", max('salary) over overCategory).
select("depName", "empNo", "name", "salary", "salaries", "average_salary", "total_salary","min","max")
df.show(false)
代码分步说明:
(1)创建一个Frame:val overCategory = Window.partitionBy('depName)
(2)基于Frame来计算平均值:
withColumn("average_salary", (avg('salary) over overCategory).cast("int")).
(3)基于Frame来求和,最大值,最小值等
withColumn("total_salary", sum('salary) over overCategory).
结果分析:
+---------+-----+------+------+--------------------+--------------+------------+----+----+
| depName|empNo| name|salary| salaries|average_salary|total_salary| min| max|
+---------+-----+------+------+--------------------+--------------+------------+----+----+
| develop| 7|Astrid| 4200|[4200, 6000, 4500...| 5020| 25100|4200|6000|
| develop| 8| Saga| 6000|[4200, 6000, 4500...| 5020| 25100|4200|6000|
| develop| 9| Freja| 4500|[4200, 6000, 4500...| 5020| 25100|4200|6000|
| develop| 10| Wilma| 5200|[4200, 6000, 4500...| 5020| 25100|4200|6000|
| develop| 11| Maja| 5200|[4200, 6000, 4500...| 5020| 25100|4200|6000|
| sales| 1| Alice| 5000| [5000, 4800, 4800]| 4866| 14600|4800|5000|
| sales| 3| Ella| 4800| [5000, 4800, 4800]| 4866| 14600|4800|5000|
| sales| 4| Ebba| 4800| [5000, 4800, 4800]| 4866| 14600|4800|5000|
|personnel| 2|Olivia| 3900| [3900, 3500]| 3700| 7400|3500|3900|
|personnel| 5| Lilly| 3500| [3900, 3500]| 3700| 7400|3500|3900|
+---------+-----+------+------+--------------------+--------------+------------+----+----+
从 结果中可以看出,结果是没有去重的,每一行都会有一个最后的结果。而且在窗口的每个Frame中的每行 结果 都不同。
对Frame进行排序操作
有序的Frame通过partitionBy and orderBy来进行创建。有序的Frame有以下几个特征:
- 通过Window.partitionBy函数添加一个或多个列来创建
- 在partitionBy函数后面添加orderBy列
- 每一行都对应一个frame
- frame的行和相同分区的每一行不同。默认,frame包含包含 所有的前面的行和目前行。
- Aggregate/Window函数可以用到每一行row+frame,并产生一个值 。
示例代码
// 创建有序窗口
val overCategory = Window.partitionBy('depName).orderBy('salary desc)
// 对有序窗口进行运算
val df = empsalary.withColumn("salaries", collect_list('salary) over overCategory).
withColumn("avg_salary", (avg('salary) over overCategory).cast("int")).
withColumn("total_salary", sum('salary) over overCategory).
select("depName", "empNo", "name", "salary", "salaries", "avg_salary", "total_salary")
df.show(false)
结果输出
scala> df.show(false)
+---------+-----+------+------+------------------------------+----------+------------+
|depName |empNo|name |salary|salaries |avg_salary|total_salary|
+---------+-----+------+------+------------------------------+----------+------------+
|develop |8 |Saga |6000 |[6000] |6000 |6000 |
|develop |10 |Wilma |5200 |[6000, 5200, 5200] |5466 |16400 |
|develop |11 |Maja |5200 |[6000, 5200, 5200] |5466 |16400 |
|develop |9 |Freja |4500 |[6000, 5200, 5200, 4500] |5225 |20900 |
|develop |7 |Astrid|4200 |[6000, 5200, 5200, 4500, 4200]|5020 |25100 |
|sales |1 |Alice |5000 |[5000] |5000 |5000 |
|sales |3 |Ella |4800 |[5000, 4800, 4800] |4866 |14600 |
|sales |4 |Ebba |4800 |[5000, 4800, 4800] |4866 |14600 |
|personnel|2 |Olivia|3900 |[3900] |3900 |3900 |
|personnel|5 |Lilly |3500 |[3900, 3500] |3700 |7400 |
+---------+-----+------+------+------------------------------+----------+------------+
从结果可以看出,基于排序窗口进行的计算不再基于划分窗口的字段,而是基于排序的字段。排序的字段值相同时,窗口函数计算的结果相同,否则每一行的结果都不同。
而且,排序窗口的聚合计算具有累加性,排序好序的窗口中的每个值都认为是一个分组。
以上例子是通过salary字段值进行的排序,在进行窗口计算时,认为该字段排好序的每个值都是一个分组。
在分组中进行排名
下表列出了Spark支持的所有等级函数。
函数 | 描述(在分区窗口内) | 注意 |
---|---|---|
rank | 对每个分组行进行排名 | 若排序字段的数据重复,则跳过该排名。例如:1 2 2 4…,若第2个和第3个值相同,第4行会排在第4位,这样第3位就会被跳过。 |
dense_rank | 对每个分组进行排名 | 若排序字段的数据重复,不会跳过排名。例如:1 2 2 4…,在第4行的排名是3。 |
row_number | 行编号 | 对同一个窗口的排序行进行编号。 |
ntile | Ntile id | 在窗口的每个分组内,再把数据分成多个堆。参数是堆的个数,比如:5行的分组,分成2对,会先按3个一堆进行分堆。 |
percent_rank | (rank-1)/(total_rows-1) | 当需要获取窗口中每个分组的前25%的数据时,占比时很有用。 |
在窗口中使用rank()和dense_rank()函数
// 创建排序窗口
val overCategory = Window.partitionBy('depName).orderBy('salary desc)
// 在排序窗口中使用rank函数
val df = empsalary.withColumn("salaries", collect_list('salary) over overCategory).
withColumn("rank", rank() over overCategory).
withColumn("dense_rank", dense_rank() over overCategory).
withColumn("row_number", row_number() over overCategory).
withColumn("ntile", ntile(2) over overCategory).
withColumn("percent_rank", percent_rank() over overCategory).
select("depName", "empNo", "name", "salary", "rank", "dense_rank", "row_number", "ntile", "percent_rank")
执行完这两行代码后,就得到了一下的df:
scala> df.show(false)
+---------+-----+------+------+----+----------+----------+-----+------------+
|depName |empNo|name |salary|rank|dense_rank|row_number|ntile|percent_rank|
+---------+-----+------+------+----+----------+----------+-----+------------+
|develop |8 |Saga |6000 |1 |1 |1 |1 |0.0 |
|develop |10 |Wilma |5200 |2 |2 |2 |1 |0.25 |
|develop |11 |Maja |5200 |2 |2 |3 |1 |0.25 |
|develop |9 |Freja |4500 |4 |3 |4 |2 |0.75 |
|develop |7 |Astrid|4200 |5 |4 |5 |2 |1.0 |
|sales |1 |Alice |5000 |1 |1 |1 |1 |0.0 |
|sales |3 |Ella |4800 |2 |2 |2 |1 |0.5 |
|sales |4 |Ebba |4800 |2 |2 |3 |2 |0.5 |
|personnel|2 |Olivia|3900 |1 |1 |1 |1 |0.0 |
|personnel|5 |Lilly |3500 |2 |2 |2 |2 |1.0 |
+---------+-----+------+------+----+----------+----------+-----+------------+
从以上结果可以看出,在使用rank()函数时empNo为9的行排在了第4位,跳过了第3位。而dense_rank()却不会跳过。这样可以满足多种分组需求。
sql形式
可以把dataframe的形式,写成spark sql,如下:
SELECT
depName,
empNo,
name,
salary,
rank,
dense_rank,
row_number,
ntile,
percent_rank
FROM (
SELECT
depName,
empNo,
name,
salary,
rank() OVER (PARTITION BY depName ORDER BY salary DESC) as rank,
dense_rank() OVER (PARTITION BY depName ORDER BY salary DESC) as dense_rank,
ntile(3) OVER (PARTITION BY depName ORDER BY salary DESC) as ntile,
row_number() OVER (PARTITION BY depName ORDER BY salary DESC) as row_number,
percent_rank() OVER (PARTITION BY depName ORDER BY salary DESC) as percent_rank
FROM empsalary) tmp
WHERE
rank <= 2
这里要注意,该SQL的最后添加了一个条件,通过添加条件可以很容易过滤数据,从而得到自己想要的结果。比如,我们希望能够通过窗口函数来计算:每个部门薪水排在前2名的员工。
窗口函数实战
找出每个部门薪水排在前3名的所有员工,注意:是所有员工。
解题思路:要找出每个部门薪水排在前3的所有员工,就首先需要按部门对员工进行分组。然后对组内的员工进行按薪水降序排序,再取前3名。
另外要注意,这里是找出所有的员工,包括重复的,所以应该使用dense_rank()函数,这样就不会溜掉薪水相同的员工名单。
sql形式的解法
SELECT
depName,
empNo,
name,
salary,
dense_rank
FROM (
SELECT
depName,
empNo,
name,
salary,
dense_rank() OVER (PARTITION BY depName ORDER BY salary DESC) as dense_rank
FROM empsalary) tmp
WHERE
dense_rank <= 3
dataframe API形式的解法
// 创建排序窗口
val overCategory = Window.partitionBy('depName).orderBy('salary desc)
// 在排序窗口中使用rank函数
val df = empsalary.withColumn("dense_rank", dense_rank() over overCategory).
filter('dense_rank <= 3).
select("depName", "empNo", "name", "salary", "dense_rank")
小结
本文介绍了 基本窗口函数的使用,并介绍了排序窗口和基本窗口的不同点,并通过实际的例子进行了分析。