参考链接:
spark-window-functions-rangebetween-dates
假设我们有以下数据:
from pyspark.sql import Row
from pyspark.sql.window import Window
from pyspark.sql.functions import mean, col
row = Row("name", "date", "score")
rdd = sc.parallelize([
row("Ali", "2020-01-01", 10.0),
row("Ali", "2020-01-02", 15.0),
row("Ali", "2020-01-03", 20.0),
row("Ali", "2020-01-04", 25.0),
row("Ali", "2020-01-05", 30.0),
row("Bob", "2020-01-01", 15.0),
row("Bob", "2020-01-02", 20.0),
row("Bob", "2020-01-03", 30.0)
])
df = rdd.toDF().withColumn("date", col("date").cast("date"))
我们使用分组的形式计算每个人的平均分,其他数据保留,则可用如下代码:
w1 = Window().partitionBy(col("name"))
df.withColumn("mean1", mean("score").over(w1)).show()
+----+----------+-----+------------------+
|name| date|score| mean1|
+----+----------+-----+------------------+
| Bob|2020-01-01| 15.0|21.666666666666668|
| Bob|2020-01-02| 20.0|21.666666666666668|
| Bob|2020-01-03| 30.0|21.666666666666668|
| Ali|2020-01-02| 15.0| 20.0|
| Ali|2020-01-05| 30.0| 20.0|
| Ali|2020-01-01| 10.0| 20.0|
| Ali|2020-01-03| 20.0| 20.0|
| Ali|2020-01-04| 25.0| 20.0|
+----+----------+-----+------------------+
从结果来看,新增加的一列mean1表示每个人所在的分组中所有分数的平均值。当然,你也可以求最大值、最小值或者方差之类的统计值。
下面我们来看一组变形的分组窗:
days = lambda i: i * 86400 # 一天转化为秒单位
w1 = Window().partitionBy(col("name"))
w2 = Window().partitionBy(col("name")).orderBy("date")
w3 = Window().partitionBy(col("name")).orderBy((col("date").cast("timestamp").cast("bigint")/3600/24)).rangeBetween(-4, 0)
w4 = Window().partitionBy(col("name")).orderBy("date").rowsBetween(Window.currentRow, 1)
w1就是常规的按照名字进行分组;w2在按照名字分组的基础上,对其组内的日期按照从早到晚进行排序;w3是在w2的基础上,增加了范围限制,限制在从前4天到当前日期的范围内;w4则是在w2的基础上增加了行参数的限制,在当前行到下一行范围内。
是不是还是有些迷糊,不慌,来看下按照这些分组窗统计的结果:
df.withColumn("mean1", mean("score").over(w1))\
.withColumn("mean2", mean("score").over(w2))\
.withColumn("mean3", mean("score").over(w3))\
.withColumn("mean4", mean("score").over(w4))\
.show()
+----+----------+-----+-----+------------------+------------------+-----+
|name| date|score|mean1| mean2| mean3|mean4|
+----+----------+-----+-----+------------------+------------------+-----+
| Bob|2020-01-01| 15.0| 30.0| 15.0| 15.0| 17.5|
| Bob|2020-01-02| 20.0| 30.0| 17.5| 17.5| 25.0|
| Bob|2020-01-03| 30.0| 30.0|21.666666666666668|21.666666666666668| 32.5|
| Bob|2020-01-04| 35.0| 30.0| 25.0| 25.0| 37.5|
| Bob|2020-01-05| 40.0| 30.0| 28.0| 28.0| 40.0|
| Bob|2020-01-06| 40.0| 30.0| 30.0| 33.0| 40.0|
| Ali|2020-01-01| 10.0| 20.0| 10.0| 10.0| 12.5|
| Ali|2020-01-02| 15.0| 20.0| 12.5| 12.5| 17.5|
| Ali|2020-01-03| 20.0| 20.0| 15.0| 15.0| 22.5|
| Ali|2020-01-04| 25.0| 20.0| 17.5| 17.5| 27.5|
| Ali|2020-01-05| 30.0| 20.0| 20.0| 20.0| 30.0|
+----+----------+-----+-----+------------------+------------------+-----+
我们来逐个分析一下,首先mean1列很简单,就是每个name分组内所有分数的平均值。mean2比较有意思,分组窗是按照name分组后按照日期进行了排序,于是均值是在当前行及前面所有行的范围内进行计算,这个可以看每组最后一个mean2均值,都与mean1均值相等。
mean3列是在当前行及往前数4天范围内计算均值,如Bob的最后一个mean3值是33,就是从2020-01-02开始计算的。
mean4列每次只统计当前行和下一行的数值,如果没有下一行则是其本身。
Window.unboundedPreceding
, Window.unboundedFollowing
, 以及Window.currentRow
分别用来表示前面所有行、后面所有行以及当前行,而数值的正负表示往前或往后,大小表示行数。
到这里可以稍微做一个总结:
1 单独的Window做聚合统计,仅对分组内所有数值进行计算;
2 添加orderBy排序的Window分组窗,统计时默认是从前面所有行到当前行进行计算;
3 rangeBetween结合orderBy可用来限制指定范围内的数据,例如统计一周内数据的场景;
4 rowsBetween用来限定前后指定行范围内的数据进行统计