[Python Notes] Window Functions in SparkSQL

1 Usage with spark.sql

1.1 Sample Data

from datetime import datetime
from pyspark.sql.types import *


# schema for the sample salary data
schema = StructType() \
    .add('name', StringType(), True) \
    .add('create_time', TimestampType(), True) \
    .add('department', StringType(), True) \
    .add('salary', IntegerType(), True)
df = spark.createDataFrame([
    ("Tom", datetime.strptime("2020-01-01 00:01:00", "%Y-%m-%d %H:%M:%S"), "Sales", 4500),
    ("Georgi", datetime.strptime("2020-01-02 12:01:00", "%Y-%m-%d %H:%M:%S"), "Sales", 4200),
    ("Kyoichi", datetime.strptime("2020-02-02 12:10:00", "%Y-%m-%d %H:%M:%S"), "Sales", 3000),    
    ("Berni", datetime.strptime("2020-01-10 11:01:00", "%Y-%m-%d %H:%M:%S"), "Sales", 4700),
    ("Berni", datetime.strptime("2020-01-07 11:01:00", "%Y-%m-%d %H:%M:%S"), "Sales", None),    
    ("Guoxiang", datetime.strptime("2020-01-08 12:11:00", "%Y-%m-%d %H:%M:%S"), "Sales", 4200),   
    ("Parto", datetime.strptime("2020-02-20 12:01:00", "%Y-%m-%d %H:%M:%S"), "Finance", 2700),
    ("Anneke", datetime.strptime("2020-01-02 08:20:00", "%Y-%m-%d %H:%M:%S"), "Finance", 3300),
    ("Sumant", datetime.strptime("2020-01-30 12:01:05", "%Y-%m-%d %H:%M:%S"), "Finance", 3900),
    ("Jeff", datetime.strptime("2020-01-02 12:01:00", "%Y-%m-%d %H:%M:%S"), "Marketing", 3100),
    ("Patricio", datetime.strptime("2020-01-05 12:18:00", "%Y-%m-%d %H:%M:%S"), "Marketing", 2500)
], schema=schema)
df.createOrReplaceTempView('salary')
df.show()
+--------+-------------------+----------+------+
|    name|        create_time|department|salary|
+--------+-------------------+----------+------+
|     Tom|2020-01-01 00:01:00|     Sales|  4500|
|  Georgi|2020-01-02 12:01:00|     Sales|  4200|
| Kyoichi|2020-02-02 12:10:00|     Sales|  3000|
|   Berni|2020-01-10 11:01:00|     Sales|  4700|
|   Berni|2020-01-07 11:01:00|     Sales|  null|
|Guoxiang|2020-01-08 12:11:00|     Sales|  4200|
|   Parto|2020-02-20 12:01:00|   Finance|  2700|
|  Anneke|2020-01-02 08:20:00|   Finance|  3300|
|  Sumant|2020-01-30 12:01:05|   Finance|  3900|
|    Jeff|2020-01-02 12:01:00| Marketing|  3100|
|Patricio|2020-01-05 12:18:00| Marketing|  2500|
+--------+-------------------+----------+------+

1.2 Window Functions

ranking functions

SQL          | DataFrame   | Description
row_number   | rowNumber   | unique sequence number within the partition, from 1 to n
rank         | rank        | ranking; like denseRank, equal values share the same rank, but rank skips the positions taken by ties
dense_rank   | denseRank   | ranking that does not skip positions after ties
percent_rank | percentRank | computed as (rank within partition - 1) / (rows in partition - 1); 0 if the partition has only 1 row
ntile        | ntile       | sorts the partition and splits it into n buckets; returns the current row's bucket number (starting from 1)
spark.sql("""
SELECT
    name 
    ,department
    ,salary
    ,row_number() over(partition by department order by salary) as index
    ,rank() over(partition by department order by salary) as rank
    ,dense_rank() over(partition by department order by salary) as dense_rank
    ,percent_rank() over(partition by department order by salary) as percent_rank
    ,ntile(2) over(partition by department order by salary) as ntile
FROM salary
""").toPandas()
        name  department  salary  index  rank  dense_rank  percent_rank  ntile
0   Patricio   Marketing  2500.0      1     1           1           0.0      1
1       Jeff   Marketing  3100.0      2     2           2           1.0      2
2      Berni       Sales     NaN      1     1           1           0.0      1
3    Kyoichi       Sales  3000.0      2     2           2           0.2      1
4     Georgi       Sales  4200.0      3     3           3           0.4      1
5   Guoxiang       Sales  4200.0      4     3           3           0.4      2
6        Tom       Sales  4500.0      5     5           4           0.8      2
7      Berni       Sales  4700.0      6     6           5           1.0      2
8      Parto     Finance  2700.0      1     1           1           0.0      1
9     Anneke     Finance  3300.0      2     2           2           0.5      1
10    Sumant     Finance  3900.0      3     3           3           1.0      2

analytic functions

SQL         | DataFrame   | Description
cume_dist   | cumeDist    | computed as (rows in the partition with a value <= the current value) / (total rows in the partition)
lag         | lag         | lag(input, [offset[, default]]): returns default (null by default) when the current index < offset, otherwise the value of input offset rows before the current row
lead        | lead        | the opposite of lag: the value of input offset rows after the current row
first_value | first_value | the first value of the ordered partition, up to the current row
last_value  | last_value  | the last value of the ordered partition, up to the current row
spark.sql("""
SELECT
    name 
    ,department
    ,salary
    ,row_number() over(partition by department order by salary) as index
    ,cume_dist() over(partition by department order by salary) as cume_dist
    ,lag(salary, 1) over(partition by department order by salary) as lag -- one row before the current row
    ,lead(salary, 1) over(partition by department order by salary) as lead -- one row after the current row
    ,lag(salary, 0) over(partition by department order by salary) as lag_0
    ,lead(salary, 0) over(partition by department order by salary) as lead_0
    ,first_value(salary) over(partition by department order by salary) as first_value
    ,last_value(salary) over(partition by department order by salary) as last_value 
FROM salary
""").toPandas()
        name  department  salary  index  cume_dist     lag    lead   lag_0  lead_0  first_value  last_value
0   Patricio   Marketing  2500.0      1   0.500000     NaN  3100.0  2500.0  2500.0       2500.0      2500.0
1       Jeff   Marketing  3100.0      2   1.000000  2500.0     NaN  3100.0  3100.0       2500.0      3100.0
2      Berni       Sales     NaN      1   0.166667     NaN  3000.0     NaN     NaN          NaN         NaN
3    Kyoichi       Sales  3000.0      2   0.333333     NaN  4200.0  3000.0  3000.0          NaN      3000.0
4     Georgi       Sales  4200.0      3   0.666667  3000.0  4200.0  4200.0  4200.0          NaN      4200.0
5   Guoxiang       Sales  4200.0      4   0.666667  4200.0  4500.0  4200.0  4200.0          NaN      4200.0
6        Tom       Sales  4500.0      5   0.833333  4200.0  4700.0  4500.0  4500.0          NaN      4500.0
7      Berni       Sales  4700.0      6   1.000000  4500.0     NaN  4700.0  4700.0          NaN      4700.0
8      Parto     Finance  2700.0      1   0.333333     NaN  3300.0  2700.0  2700.0       2700.0      2700.0
9     Anneke     Finance  3300.0      2   0.666667  2700.0  3900.0  3300.0  3300.0       2700.0      3300.0
10    Sumant     Finance  3900.0      3   1.000000  3300.0     NaN  3900.0  3900.0       2700.0      3900.0

aggregate functions

These simply apply ordinary aggregate functions within a window.

SQL | Description
avg | average
sum | sum
min | minimum
max | maximum
spark.sql("""
SELECT
    name 
    ,department
    ,salary
    ,row_number() over(partition by department order by salary) as index
    ,sum(salary) over(partition by department order by salary) as sum
    ,avg(salary) over(partition by department order by salary) as avg
    ,min(salary) over(partition by department order by salary) as min
    ,max(salary) over(partition by department order by salary) as max
FROM salary
""").toPandas()
        name  department  salary  index      sum     avg     min     max
0   Patricio   Marketing  2500.0      1   2500.0  2500.0  2500.0  2500.0
1       Jeff   Marketing  3100.0      2   5600.0  2800.0  2500.0  3100.0
2      Berni       Sales     NaN      1      NaN     NaN     NaN     NaN
3    Kyoichi       Sales  3000.0      2   3000.0  3000.0  3000.0  3000.0
4     Georgi       Sales  4200.0      3  11400.0  3800.0  3000.0  4200.0
5   Guoxiang       Sales  4200.0      4  11400.0  3800.0  3000.0  4200.0
6        Tom       Sales  4500.0      5  15900.0  3975.0  3000.0  4500.0
7      Berni       Sales  4700.0      6  20600.0  4120.0  3000.0  4700.0
8      Parto     Finance  2700.0      1   2700.0  2700.0  2700.0  2700.0
9     Anneke     Finance  3300.0      2   6000.0  3000.0  2700.0  3300.0
10    Sumant     Finance  3900.0      3   9900.0  3300.0  2700.0  3900.0

1.3 Window Frame Clause

The ROWS/RANGE frame clause controls the size and boundaries of the window. There are two kinds (ROWS and RANGE); a sketch contrasting them follows the list below.

  • ROWS: a physical window; rows are selected by their position (index) after sorting
  • RANGE: a logical window; rows are selected by their value

Syntax: OVER (PARTITION BY … ORDER BY … frame_type BETWEEN start AND end)

The following frame boundaries are available:

  • CURRENT ROW: the current row
  • UNBOUNDED PRECEDING: the first row of the partition
  • UNBOUNDED FOLLOWING: the last row of the partition
  • n PRECEDING: n rows before the current row
  • n FOLLOWING: n rows after the current row
  • UNBOUNDED: the starting point (partition boundary)
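
Before the full example, a minimal sketch (not from the original) contrasting ROWS and RANGE under the same ORDER BY salary: with RANGE, rows that tie on salary (the two 4200 values) fall into the same frame and share a running sum, while with ROWS they do not.

spark.sql("""
SELECT
    name
    ,department
    ,salary
    ,sum(salary) over(partition by department order by salary
                      rows  between unbounded preceding and current row) as sum_rows
    ,sum(salary) over(partition by department order by salary
                      range between unbounded preceding and current row) as sum_range
FROM salary
WHERE department = 'Sales' AND salary is not null
""").show()

The larger example below combines a ROWS frame with collect_list, last, lag and lead: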
spark.sql("""
SELECT
    name 
    ,department
    ,create_time
    ,row_number() over(partition by department order by create_time) as index
    ,row_number() over(partition by department order by (case when salary is not null then create_time end)) as index_ignore_null
    ,salary    
    ,collect_list(salary) over(partition by department order by create_time rows between UNBOUNDED PRECEDING AND 1 PRECEDING) as before_salarys
    ,last(salary) over(partition by department order by create_time rows between UNBOUNDED PRECEDING AND 1 PRECEDING) as before_salary1
    ,lag(salary, 1) over(partition by department order by create_time) as before_salary2
    ,lead(salary, 1) over(partition by department order by create_time) as after_salary   
FROM salary
ORDER BY department, index
""").toPandas()
        name  department          create_time  index  index_ignore_null  salary            before_salarys  before_salary1  before_salary2  after_salary
0     Anneke     Finance  2020-01-02 08:20:00      1                  1  3300.0                        []             NaN             NaN        3900.0
1     Sumant     Finance  2020-01-30 12:01:05      2                  2  3900.0                    [3300]          3300.0          3300.0        2700.0
2      Parto     Finance  2020-02-20 12:01:00      3                  3  2700.0              [3300, 3900]          3900.0          3900.0           NaN
3       Jeff   Marketing  2020-01-02 12:01:00      1                  1  3100.0                        []             NaN             NaN        2500.0
4   Patricio   Marketing  2020-01-05 12:18:00      2                  2  2500.0                    [3100]          3100.0          3100.0           NaN
5        Tom       Sales  2020-01-01 00:01:00      1                  2  4500.0                        []             NaN             NaN        4200.0
6     Georgi       Sales  2020-01-02 12:01:00      2                  3  4200.0                    [4500]          4500.0          4500.0           NaN
7      Berni       Sales  2020-01-07 11:01:00      3                  1     NaN              [4500, 4200]          4200.0          4200.0        4200.0
8   Guoxiang       Sales  2020-01-08 12:11:00      4                  4  4200.0              [4500, 4200]             NaN             NaN        4700.0
9      Berni       Sales  2020-01-10 11:01:00      5                  5  4700.0        [4500, 4200, 4200]          4200.0          4200.0        3000.0
10   Kyoichi       Sales  2020-02-02 12:10:00      6                  6  3000.0  [4500, 4200, 4200, 4700]          4700.0          4700.0           NaN
# Within the same department: the salary of the previously hired colleague whose salary is not null
# (an alternative using last(..., ignoreNulls) is sketched after the result below)
spark.sql("""
SELECT
    name
    ,department
    ,create_time
    ,index
    ,salary
    ,before_salarys[size(before_salarys)-1] as before_salary
FROM(
    SELECT
        name 
        ,department
        ,create_time
        ,row_number() over(partition by department order by create_time) as index
        ,salary    
        ,collect_list(salary) over(partition by department order by create_time rows between UNBOUNDED PRECEDING AND 1 PRECEDING) as before_salarys 
    FROM salary
    ORDER BY department, index
) AS base
""").toPandas()
        name  department          create_time  index  salary  before_salary
0     Anneke     Finance  2020-01-02 08:20:00      1  3300.0            NaN
1     Sumant     Finance  2020-01-30 12:01:05      2  3900.0         3300.0
2      Parto     Finance  2020-02-20 12:01:00      3  2700.0         3900.0
3       Jeff   Marketing  2020-01-02 12:01:00      1  3100.0            NaN
4   Patricio   Marketing  2020-01-05 12:18:00      2  2500.0         3100.0
5        Tom       Sales  2020-01-01 00:01:00      1  4500.0            NaN
6     Georgi       Sales  2020-01-02 12:01:00      2  4200.0         4500.0
7      Berni       Sales  2020-01-07 11:01:00      3     NaN         4200.0
8   Guoxiang       Sales  2020-01-08 12:11:00      4  4200.0         4200.0
9      Berni       Sales  2020-01-10 11:01:00      5  4700.0         4200.0
10   Kyoichi       Sales  2020-02-02 12:10:00      6  3000.0         4700.0
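
As a hedged alternative sketch (not from the original): Spark SQL's last(expr, true) form skips NULLs, so the previous non-null salary can be read directly from a frame that ends one row before the current row, without collect_list, assuming this two-argument form is available in your Spark version.

spark.sql("""
SELECT
    name
    ,department
    ,create_time
    ,salary
    ,last(salary, true) over(partition by department order by create_time
                             rows between UNBOUNDED PRECEDING AND 1 PRECEDING) as before_salary
FROM salary
ORDER BY department, create_time
""").toPandas()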

1.4 Combined Examples

spark.sql("""
SELECT
    name 
    ,department
    ,salary
    ,row_number() over(partition by department order by salary) as index
    ,salary - (min(salary) over(partition by department order by salary)) as salary_diff -- how much higher than the department's lowest salary
    ,min(salary) over() as min_salary_0 -- overall minimum salary
    ,first_value(salary) over(order by salary) as max_salary_1
    
    ,max(salary) over(order by salary) as current_max_salary_0 -- running maximum salary up to the current row
    ,last_value(salary) over(order by salary) as current_max_salary_1 
    
    ,max(salary) over(partition by department order by salary rows between 1 FOLLOWING and 1 FOLLOWING) as next_salary_0 -- the next record when ordered by salary
    ,lead(salary) over(partition by department order by salary) as next_salary_1
FROM salary
WHERE salary is not null
""").toPandas()
       name  department  salary  index  salary_diff  min_salary_0  max_salary_1  current_max_salary_0  current_max_salary_1  next_salary_0  next_salary_1
0  Patricio   Marketing    2500      1            0          2500          2500                  2500                  2500         3100.0         3100.0
1     Parto     Finance    2700      1            0          2500          2500                  2700                  2700         3300.0         3300.0
2   Kyoichi       Sales    3000      1            0          2500          2500                  3000                  3000         4200.0         4200.0
3      Jeff   Marketing    3100      2          600          2500          2500                  3100                  3100            NaN            NaN
4    Anneke     Finance    3300      2          600          2500          2500                  3300                  3300         3900.0         3900.0
5    Sumant     Finance    3900      3         1200          2500          2500                  3900                  3900            NaN            NaN
6    Georgi       Sales    4200      2         1200          2500          2500                  4200                  4200         4200.0         4200.0
7  Guoxiang       Sales    4200      3         1200          2500          2500                  4200                  4200         4500.0         4500.0
8       Tom       Sales    4500      4         1500          2500          2500                  4500                  4500         4700.0         4700.0
9     Berni       Sales    4700      5         1700          2500          2500                  4700                  4700            NaN            NaN

Reference: SparkSQL | 窗口函数

2 pyspark

This part groups Window functions into three categories: ranking functions, analytic functions, and aggregate functions.

  • ranking functions include row_number(), rank(), dense_rank(), percent_rank(), ntile();
  • analytic functions include cume_dist(), lag(), lead();
  • aggregate functions include sum(), first(), last(), max(), min(), mean(), stddev(), and so on.

2.1 Ranking functions

First, assume our data looks like this:

# spark = SparkSession.builder.appName('Window functions').getOrCreate()
employee_salary = [
    ("Ali", "Sales", 8000),
    ("Bob", "Sales", 7000),
    ("Cindy", "Sales", 7500),
    ("Davd", "Finance", 10000),
    ("Elena", "Sales", 8000),
    ("Fancy", "Finance", 12000),
    ("George", "Finance", 11000),
    ("Haffman", "Marketing", 7000),
    ("Ilaja", "Marketing", 8000),
    ("Joey", "Sales", 9000)]
 
columns= ["name", "department", "salary"]
df = spark.createDataFrame(data = employee_salary, schema = columns)
df.show(truncate=False)
+-------+----------+------+
|name   |department|salary|
+-------+----------+------+
|Ali    |Sales     |8000  |
|Bob    |Sales     |7000  |
|Cindy  |Sales     |7500  |
|Davd   |Finance   |10000 |
|Elena  |Sales     |8000  |
|Fancy  |Finance   |12000 |
|George |Finance   |11000 |
|Haffman|Marketing |7000  |
|Ilaja  |Marketing |8000  |
|Joey   |Sales     |9000  |
+-------+----------+------+

row_number()

row_number() adds a row number to a partitioned window ordered by the specified column; the number starts at 1 and increases by one as you move through the ordering of the partition.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

windowSpec  = Window.partitionBy("department").orderBy(F.desc("salary"))
df.withColumn("row_number", F.row_number().over(windowSpec)).show(truncate=False)

The data is partitioned by department and sorted by salary from high to low; the result is:

+-------+----------+------+----------+
|name   |department|salary|row_number|
+-------+----------+------+----------+
|Joey   |Sales     |9000  |1         |
|Ali    |Sales     |8000  |2         |
|Elena  |Sales     |8000  |3         |
|Cindy  |Sales     |7500  |4         |
|Bob    |Sales     |7000  |5         |
|Fancy  |Finance   |12000 |1         |
|George |Finance   |11000 |2         |
|Davd   |Finance   |10000 |3         |
|Ilaja  |Marketing |8000  |1         |
|Haffman|Marketing |7000  |2         |
+-------+----------+------+----------+

Looking at the output, you will notice that identical salaries get different row numbers: row_number() numbers rows by position and ignores the actual values (a sketch below shows how to make this deterministic). This leads us to another ranking function, rank().
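
Since row_number() numbers tied values arbitrarily, a secondary sort key makes the assignment deterministic. A minimal sketch (the choice of "name" as the tie-breaker is an assumption, not from the original):

from pyspark.sql.window import Window
import pyspark.sql.functions as F

# break ties on salary by name so the row_number() assignment is reproducible
windowSpec = Window.partitionBy("department").orderBy(F.desc("salary"), F.asc("name"))
df.withColumn("row_number", F.row_number().over(windowSpec)).show(truncate=False)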

rank()

rank() adds a ranking number to a partitioned window ordered by the specified column; rows with equal values share the same rank, and the following rank is pushed back accordingly.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

windowSpec  = Window.partitionBy("department").orderBy(F.desc("salary"))
df.withColumn("rank",F.rank().over(windowSpec)).show(truncate=False)

Partition by department and sort each group's salaries from high to low; the result is:

+-------+----------+------+----+
|name   |department|salary|rank|
+-------+----------+------+----+
|Joey   |Sales     |9000  |1   |
|Ali    |Sales     |8000  |2   |
|Elena  |Sales     |8000  |2   |
|Cindy  |Sales     |7500  |4   |
|Bob    |Sales     |7000  |5   |
|Fancy  |Finance   |12000 |1   |
|George |Finance   |11000 |2   |
|Davd   |Finance   |10000 |3   |
|Ilaja  |Marketing |8000  |1   |
|Haffman|Marketing |7000  |2   |
+-------+----------+------+----+

In the result above, the two identical 8000 values both get rank 2, and the next rank naturally jumps to 4. This brings us to another ranking function, dense_rank().

dense_rank()

dense_rank() also ranks rows within a partitioned window ordered by the specified column, but without skipping: equal values share the same rank and the ranks of subsequent values are not affected. Consider the following code:

from pyspark.sql.window import Window
import pyspark.sql.functions as F

windowSpec  = Window.partitionBy("department").orderBy(F.desc("salary"))
df.withColumn("dense_rank",F.dense_rank().over(windowSpec)).show()

Partition by department and sort each group's data by salary from high to low; the result is:

+-------+----------+------+----------+
|   name|department|salary|dense_rank|
+-------+----------+------+----------+
|   Joey|     Sales|  9000|         1|
|    Ali|     Sales|  8000|         2|
|  Elena|     Sales|  8000|         2|
|  Cindy|     Sales|  7500|         3|
|    Bob|     Sales|  7000|         4|
|  Fancy|   Finance| 12000|         1|
| George|   Finance| 11000|         2|
|   Davd|   Finance| 10000|         3|
|  Ilaja| Marketing|  8000|         1|
|Haffman| Marketing|  7000|         2|
+-------+----------+------+----------+

percent_rank()

In some business scenarios we need a percentile-style rank for the different values.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

windowSpec  = Window.partitionBy("department").orderBy(F.desc("salary"))
df.withColumn("percent_rank",F.percent_rank().over(windowSpec)).show()

Partition by department, sort each person's salary within the group, and add a column with percent_rank(); the result is:

+-------+----------+------+------------+
|   name|department|salary|percent_rank|
+-------+----------+------+------------+
|   Joey|     Sales|  9000|         0.0|
|    Ali|     Sales|  8000|        0.25|
|  Elena|     Sales|  8000|        0.25|
|  Cindy|     Sales|  7500|        0.75|
|    Bob|     Sales|  7000|         1.0|
|  Fancy|   Finance| 12000|         0.0|
| George|   Finance| 11000|         0.5|
|   Davd|   Finance| 10000|         1.0|
|  Ilaja| Marketing|  8000|         0.0|
|Haffman| Marketing|  7000|         1.0|
+-------+----------+------+------------+

The result can be understood as a normalization of the rank() values, computed as (rank - 1) / (rows in partition - 1), which yields a number between 0 and 1. percent_rank() behaves the same as the PERCENT_RANK function in SQL.
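
An optional small sketch (not from the original) re-deriving percent_rank() as (rank - 1) / (partition row count - 1) to confirm the formula; note the manual version divides by zero on single-row partitions, where percent_rank() itself returns 0.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

windowSpec = Window.partitionBy("department").orderBy(F.desc("salary"))
partSpec = Window.partitionBy("department")

df.withColumn("percent_rank", F.percent_rank().over(windowSpec)) \
  .withColumn("manual", (F.rank().over(windowSpec) - 1) /
                        (F.count(F.lit(1)).over(partSpec) - 1)) \
  .show()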

ntile()

ntile() splits the partitioned data into n parts according to the given number n, assigning the same bucket number to the rows of each part based on their order. For example, with n = 2 the partition is split into two parts: the first gets bucket number 1 and the second gets 2. In principle the two parts have equal row counts, but when the row count is odd the middle row goes to the first part.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

windowSpec  = Window.partitionBy("department").orderBy(F.desc("salary"))
df.withColumn("ntile",F.ntile(2).over(windowSpec)).show()

Partition the data by department, sort by salary within each group, and then use ntile() to split each group into two parts. The result is:

+-------+----------+------+-----+
|   name|department|salary|ntile|
+-------+----------+------+-----+
|   Joey|     Sales|  9000|    1|
|    Ali|     Sales|  8000|    1|
|  Elena|     Sales|  8000|    1|
|  Cindy|     Sales|  7500|    2|
|    Bob|     Sales|  7000|    2|
|  Fancy|   Finance| 12000|    1|
| George|   Finance| 11000|    1|
|   Davd|   Finance| 10000|    2|
|  Ilaja| Marketing|  8000|    1|
|Haffman| Marketing|  7000|    2|
+-------+----------+------+-----+

2.2 Analytic functions

cume_dist()

cume_dist() returns the cumulative distribution of the values within the window.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

windowSpec  = Window.partitionBy("department").orderBy(F.desc("salary"))
df.withColumn("cume_dist",F.cume_dist().over(windowSpec)).show()

Partition by department, sort by salary, then use cume_dist() to get the cumulative distribution; the result is:

+-------+----------+------+------------------+
|   name|department|salary|         cume_dist|
+-------+----------+------+------------------+
|   Joey|     Sales|  9000|               0.2|
|    Ali|     Sales|  8000|               0.6|
|  Elena|     Sales|  8000|               0.6|
|  Cindy|     Sales|  7500|               0.8|
|    Bob|     Sales|  7000|               1.0|
|  Fancy|   Finance| 12000|0.3333333333333333|
| George|   Finance| 11000|0.6666666666666666|
|   Davd|   Finance| 10000|               1.0|
|  Ilaja| Marketing|  8000|               0.5|
|Haffman| Marketing|  7000|               1.0|
+-------+----------+------+------------------+

The result looks quite similar to percent_rank(), and it is indeed another kind of normalization: cume_dist() divides the number of rows ordered at or before the current row by the total number of rows, where tied values are counted together and take the position of the last tied row. For the Sales group these counts are 1, 3, 3, 4, 5, and dividing by 5 gives 0.2, 0.6, 0.6, 0.8, 1. In business terms: 20% of people earn 9000 or more, 60% earn 8000 or more, 80% earn 7500 or more, and 100% earn 7000 or more, which makes the numbers much easier to interpret. A sketch reproducing this by hand follows.
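
A sketch (not from the original) reproducing cume_dist() by hand: an ORDER BY plus a rangeBetween frame counts tied salaries together, and dividing by the partition size gives the cumulative distribution.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

orderedSpec = Window.partitionBy("department").orderBy(F.desc("salary")) \
                    .rangeBetween(Window.unboundedPreceding, Window.currentRow)
partSpec = Window.partitionBy("department")

df.withColumn("cume_dist", F.cume_dist().over(
        Window.partitionBy("department").orderBy(F.desc("salary")))) \
  .withColumn("manual", F.count(F.lit(1)).over(orderedSpec) /
                        F.count(F.lit(1)).over(partSpec)) \
  .show()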

lag()

lag() finds, for each row in the partition ordered by the specified column, the value of the previous row; in plain words, once the values are sorted, it returns the value that comes just before each one.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

windowSpec  = Window.partitionBy("department").orderBy(F.desc("salary"))
df.withColumn("lag",F.lag("salary",1).over(windowSpec)).show()

Partition by department, sort by salary within each group, and fetch the previous salary value for each row; the result is:

+-------+----------+------+-----+
|   name|department|salary|  lag|
+-------+----------+------+-----+
|   Joey|     Sales|  9000| null|
|    Ali|     Sales|  8000| 9000|
|  Elena|     Sales|  8000| 8000|
|  Cindy|     Sales|  7500| 8000|
|    Bob|     Sales|  7000| 7500|
|  Fancy|   Finance| 12000| null|
| George|   Finance| 11000|12000|
|   Davd|   Finance| 10000|11000|
|  Ilaja| Marketing|  8000| null|
|Haffman| Marketing|  7000| 8000|
+-------+----------+------+-----+

The counterpart of lag() that fetches the next value is lead().

lead()

lead() fetches the value that comes after the current one in the ordered partition; the code is as follows:

from pyspark.sql.window import Window
import pyspark.sql.functions as F

windowSpec  = Window.partitionBy("department").orderBy(F.desc("salary"))
df.withColumn("lead",F.lead("salary",1).over(windowSpec)).show()

Partition by department, sort salaries within the group, and use lead to fetch the next salary value for each row; the result is:

+-------+----------+------+-----+
|   name|department|salary| lead|
+-------+----------+------+-----+
|   Joey|     Sales|  9000| 8000|
|    Ali|     Sales|  8000| 8000|
|  Elena|     Sales|  8000| 7500|
|  Cindy|     Sales|  7500| 7000|
|    Bob|     Sales|  7000| null|
|  Fancy|   Finance| 12000|11000|
| George|   Finance| 11000|10000|
|   Davd|   Finance| 10000| null|
|  Ilaja| Marketing|  8000| 7000|
|Haffman| Marketing|  7000| null|
+-------+----------+------+-----+

In real business scenarios, suppose we have monthly sales figures and want to know how a given month compares with the previous or the following month; lag and lead are exactly the tools for that kind of analysis. A minimal sketch follows.
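
A minimal sketch of that idea (the monthly_sales rows below are hypothetical, not from the original data):

from pyspark.sql.window import Window
import pyspark.sql.functions as F

monthly_sales = spark.createDataFrame(
    [("2020-01", 100.0), ("2020-02", 120.0), ("2020-03", 90.0)],
    ["month", "sales"])

# no partitionBy: a single global window (Spark will warn about this on large data)
w = Window.orderBy("month")
monthly_sales \
    .withColumn("prev_month_sales", F.lag("sales", 1).over(w)) \
    .withColumn("mom_change", F.col("sales") - F.lag("sales", 1).over(w)) \
    .show()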

2.3 Aggregate Functions

Common aggregate functions include avg, sum, min, max, count, approx_count_distinct(), and so on. The following code uses several of them at once:

from pyspark.sql.window import Window
import pyspark.sql.functions as F

windowSpec  = Window.partitionBy("department").orderBy(F.desc("salary"))
windowSpecAgg  = Window.partitionBy("department")

df.withColumn("row", F.row_number().over(windowSpec)) \
  .withColumn("avg", F.avg("salary").over(windowSpecAgg)) \
  .withColumn("sum", F.sum("salary").over(windowSpecAgg)) \
  .withColumn("min", F.min("salary").over(windowSpecAgg)) \
  .withColumn("max", F.max("salary").over(windowSpecAgg)) \
  .withColumn("count", F.count("salary").over(windowSpecAgg)) \
  .withColumn("distinct_count", F.approx_count_distinct("salary").over(windowSpecAgg)) \
  .show()
+-------+----------+------+---+-------+-----+-----+-----+-----+--------------+
|   name|department|salary|row|    avg|  sum|  min|  max|count|distinct_count|
+-------+----------+------+---+-------+-----+-----+-----+-----+--------------+
|   Joey|     Sales|  9000|  1| 7900.0|39500| 7000| 9000|    5|             4|
|    Ali|     Sales|  8000|  2| 7900.0|39500| 7000| 9000|    5|             4|
|  Elena|     Sales|  8000|  3| 7900.0|39500| 7000| 9000|    5|             4|
|  Cindy|     Sales|  7500|  4| 7900.0|39500| 7000| 9000|    5|             4|
|    Bob|     Sales|  7000|  5| 7900.0|39500| 7000| 9000|    5|             4|
|  Fancy|   Finance| 12000|  1|11000.0|33000|10000|12000|    3|             3|
| George|   Finance| 11000|  2|11000.0|33000|10000|12000|    3|             3|
|   Davd|   Finance| 10000|  3|11000.0|33000|10000|12000|    3|             3|
|  Ilaja| Marketing|  8000|  1| 7500.0|15000| 7000| 8000|    2|             2|
|Haffman| Marketing|  7000|  2| 7500.0|15000| 7000| 8000|    2|             2|
+-------+----------+------+---+-------+-----+-----+-----+-----+--------------+

Note that approx_count_distinct() is suited to window-based statistics, whereas in a groupBy you would normally use countDistinct() instead to count the distinct values within a group (a groupBy comparison is sketched after the output below). approx_count_distinct() returns an approximate, not exact, value, so use it with care. Looking at the result, the statistics are computed per department over each group's salary values. If we only want to keep the per-department statistics and drop the individual rows, we can use the following code:

from pyspark.sql.window import Window
import pyspark.sql.functions as F

windowSpec  = Window.partitionBy("department").orderBy(F.desc("salary"))
windowSpecAgg  = Window.partitionBy("department")

df.withColumn("row", F.row_number().over(windowSpec)) \
  .withColumn("avg", F.avg("salary").over(windowSpecAgg)) \
  .withColumn("sum", F.sum("salary").over(windowSpecAgg)) \
  .withColumn("min", F.min("salary").over(windowSpecAgg)) \
  .withColumn("max", F.max("salary").over(windowSpecAgg)) \
  .withColumn("count", F.count("salary").over(windowSpecAgg)) \
  .withColumn("distinct_count", F.approx_count_distinct("salary").over(windowSpecAgg)) \
  .where(F.col("row")==1).select("department","avg","sum","min","max","count","distinct_count") \
  .show()
+----------+-------+-----+-----+-----+-----+--------------+
|department|    avg|  sum|  min|  max|count|distinct_count|
+----------+-------+-----+-----+-----+-----+--------------+
|     Sales| 7900.0|39500| 7000| 9000|    5|             4|
|   Finance|11000.0|33000|10000|12000|    3|             3|
| Marketing| 7500.0|15000| 7000| 8000|    2|             2|
+----------+-------+-----+-----+-----+-----+--------------+
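
For comparison, a minimal groupBy sketch (not from the original) that produces the same per-department statistics without window functions, using countDistinct() as mentioned above:

import pyspark.sql.functions as F

df.groupBy("department").agg(
    F.avg("salary").alias("avg"),
    F.sum("salary").alias("sum"),
    F.min("salary").alias("min"),
    F.max("salary").alias("max"),
    F.count("salary").alias("count"),
    F.countDistinct("salary").alias("distinct_count")
).show()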

Reference: PySpark–Window Functions

2.4 Partitioned Windows

Partitioned windows are very useful in practice; for more background on Window, see the article Window不同分组窗的使用.
Suppose we have the following data:

from pyspark.sql import Row
from pyspark.sql.window import Window
from pyspark.sql.functions import mean, col

row = Row("name", "date", "score")
rdd = sc.parallelize([
    row("Ali", "2020-01-01", 10.0),
    row("Ali", "2020-01-02", 15.0),
    row("Ali", "2020-01-03", 20.0),
    row("Ali", "2020-01-04", 25.0),
    row("Ali", "2020-01-05", 30.0),
    row("Bob", "2020-01-01", 15.0),
    row("Bob", "2020-01-02", 20.0),
    row("Bob", "2020-01-03", 30.0)
])
df = rdd.toDF().withColumn("date", col("date").cast("date"))

To compute each person's average score with a partitioned window while keeping the other columns, we can use the following code:

w1 = Window().partitionBy(col("name"))
df.withColumn("mean1", mean("score").over(w1)).show()
+----+----------+-----+------------------+
|name|      date|score|             mean1|
+----+----------+-----+------------------+
| Bob|2020-01-01| 15.0|21.666666666666668|
| Bob|2020-01-02| 20.0|21.666666666666668|
| Bob|2020-01-03| 30.0|21.666666666666668|
| Ali|2020-01-02| 15.0|              20.0|
| Ali|2020-01-05| 30.0|              20.0|
| Ali|2020-01-01| 10.0|              20.0|
| Ali|2020-01-03| 20.0|              20.0|
| Ali|2020-01-04| 25.0|              20.0|
+----+----------+-----+------------------+

From the result, the new column mean1 holds the average of all scores in each person's partition. You can of course compute the maximum, minimum, standard deviation, or similar statistics the same way (see the sketch below).
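
As noted above, other aggregates work the same way over w1; a small sketch:

from pyspark.sql.functions import max as max_, min as min_, stddev

df.withColumn("max1", max_("score").over(w1)) \
  .withColumn("min1", min_("score").over(w1)) \
  .withColumn("std1", stddev("score").over(w1)) \
  .show()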

Now let's look at a set of variations on the window spec:

days = lambda i: i * 86400  # convert days to seconds

w1 = Window().partitionBy(col("name"))
w2 = Window().partitionBy(col("name")).orderBy("date")
w3 = Window().partitionBy(col("name")).orderBy((col("date").cast("timestamp").cast("bigint")/3600/24)).rangeBetween(-4, 0)
w4 = Window().partitionBy(col("name")).orderBy("date").rowsBetween(Window.currentRow, 1)
  • w1 is the usual partition by name;
  • w2 partitions by name and additionally orders each partition's dates from earliest to latest;
  • w3 builds on w2 and adds a range constraint: from 4 days before the current date up to the current date;
  • w4 builds on w2 and adds a row constraint: from the current row to the next row.

Still a bit confusing? No worries, here are the statistics computed with these window specs:

df.withColumn("mean1", mean("score").over(w1))\
  .withColumn("mean2", mean("score").over(w2))\
  .withColumn("mean3", mean("score").over(w3))\
  .withColumn("mean4", mean("score").over(w4))\
  .show()
+----+----------+-----+-----+------------------+------------------+-----+
|name|      date|score|mean1|             mean2|             mean3|mean4|
+----+----------+-----+-----+------------------+------------------+-----+
| Bob|2020-01-01| 15.0| 30.0|              15.0|              15.0| 17.5|
| Bob|2020-01-02| 20.0| 30.0|              17.5|              17.5| 25.0|
| Bob|2020-01-03| 30.0| 30.0|21.666666666666668|21.666666666666668| 32.5|
| Bob|2020-01-04| 35.0| 30.0|              25.0|              25.0| 37.5|
| Bob|2020-01-05| 40.0| 30.0|              28.0|              28.0| 40.0|
| Bob|2020-01-06| 40.0| 30.0|              30.0|              33.0| 40.0|
| Ali|2020-01-01| 10.0| 20.0|              10.0|              10.0| 12.5|
| Ali|2020-01-02| 15.0| 20.0|              12.5|              12.5| 17.5|
| Ali|2020-01-03| 20.0| 20.0|              15.0|              15.0| 22.5|
| Ali|2020-01-04| 25.0| 20.0|              17.5|              17.5| 27.5|
| Ali|2020-01-05| 30.0| 20.0|              20.0|              20.0| 30.0|
+----+----------+-----+-----+------------------+------------------+-----+
  • mean1 is straightforward: the average of all scores within each name partition
  • mean2 is more interesting: since the window is partitioned by name and ordered by date, the mean is computed over the current row and all preceding rows; note that the last mean2 value in each partition equals mean1
  • mean3 is the mean over the current row and the previous 4 days; for example Bob's last mean3 value is 33, computed starting from 2020-01-02
  • mean4 averages only the current row and the next row; if there is no next row it is just the row's own value.

  • Window.unboundedPreceding: all rows before the current row
  • Window.unboundedFollowing: all rows after the current row
  • Window.currentRow: the current row

The sign of a numeric boundary indicates whether it reaches backward or forward, and its magnitude gives the number of rows (or the size of the range). A sketch of an explicit frame specification follows.
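
A small sketch (not from the original): when an orderBy is present, Spark's default frame is rangeBetween(Window.unboundedPreceding, Window.currentRow), so w2 above can be written out explicitly and produces the same mean2 values.

w2_explicit = Window().partitionBy(col("name")).orderBy("date") \
                      .rangeBetween(Window.unboundedPreceding, Window.currentRow)
df.withColumn("mean2_explicit", mean("score").over(w2_explicit)).show()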


Summary

  1. A Window with only a partitionBy aggregates over all values in the partition;
  2. Adding an orderBy to the Window means statistics are computed, by default, from the first row of the partition up to the current row;
  3. rangeBetween combined with orderBy restricts the data to a given value range, for example statistics over the last week of data (see the sketch after this list);
  4. rowsBetween restricts the statistics to a specified number of rows before and after the current row.
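
A minimal sketch of item 3 above (a rolling 7-day window), reusing the days() helper defined earlier; the frame covers the current day and the previous 6 days.

w_week = Window().partitionBy(col("name")) \
                 .orderBy(col("date").cast("timestamp").cast("bigint")) \
                 .rangeBetween(-days(6), Window.currentRow)
df.withColumn("mean_7d", mean("score").over(w_week)).show()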