这篇文章主要介绍在Spark中如何分组取TopN元素的两种方法:
- 第一种方法基于Spark SQL的窗口函数实现,
- 第二种方法基于原生的RDD接口实现。
构造数据
首先我们构造一份班级的成绩数据,这份数据有三列组成,第一列是考试科目category,第二列是学生的名字name,第三列是学生的成绩。如下:
val df = spark.createDataFrame(Seq(
("A", "Tom", 78),
("B", "James", 47),
("A", "Jim", 43),
("C", "James", 89),
("A", "Lee", 93),
("C", "Jim", 65),
("A", "James", 10),
("C", "Lee", 39),
("B", "Tom", 99),
("C", "Tom", 53),
("B", "Lee", 100),
("B", "Jim", 100)
)).toDF("category", "name", "score")
df.show(false)
输出:
+--------+-----+-----+ |category|name |score| +--------+-----+-----+ |A |Tom |78 | |B |James|47 | |A |Jim |43 | |C |James|89 | |A |Lee |93 | |C |Jim |65 | |A |James|10 | |C |Lee |39 | |B |Tom |99 | |C |Tom |53 | |B |Lee |100 | |B |Jim |100 | +--------+-----+-----+ |
我们要实现的目标是:取出每个科目下成绩排名前三的学生
1、使用窗口函数实现
Spark SQL从1.4开始支持窗口分析函数,我们可以使用窗口函数row_number来进行分组排序,然后在对每个分区取出TopN个元素。row_number函数作用于一个分区,并为该分区中的每条记录生成一个从1开始递增的序列号,这样在外层循环就可以通过过滤该序列号来获取特定的数据。
①使用窗口函数取TopN
val N = 3
val window = Window.partitionBy(col("category")).orderBy(col("score").desc)
val top3DF = df.withColumn("topn", row_number().over(window)).where(col("topn") <= N)
top3DF.show(false)
输出:
+--------+-----+-----+----+ |category|name |score|topn| +--------+-----+-----+----+ |B |Lee |100 |1 | |B |Jim |100 |2 | |B |Tom |99 |3 | |C |James|89 |1 | |C |Jim |65 |2 | |C |Tom |53 |3 | |A |Lee |93 |1 | |A |Tom |78 |2 | |A |Jim |43 |3 | +--------+-----+-----+----+ |
②也可以使用sql查询的方式
df.createOrReplaceTempView("grade")
val sql = "select category, name, score from (select category, name, score, row_number() over (partition by category order by score desc ) rank from grade) g where g.rank <= 3".stripMargin
val top3DFBySQL = spark.sql(sql)
top3DFBySQL.show(false)
2、使用元素RDD接口实现
使用原生RDD接口来获取TopN元素主要需要以下三个步骤:
- 将数据按指定的标准分组。比如在本例中需要按“category”分组
- 对每个分组中的元素进行排序,然后取TopN个元素
- 将以上数据Flat展开,恢复为原有格式
// 使用RDD取Top
// step 1: 分组
val groupRDD = df.rdd.map(x => (x.getString(0), (x.getString(1), x.getInt(2)))).groupByKey()
// step 2: 排序并取TopN
val N = 3
val sortedRDD = groupRDD.map(x => {
val rawRows = x._2.toBuffer
val sortedRows = rawRows.sortBy(_._2.asInstanceOf[Integer])
// 取TopN元素
if (sortedRows.size > N) {
sortedRows.remove(0, (sortedRows.length - N))
}
(x._1, sortedRows.toIterator)
})
// step 3: 展开
val flatRDD = sortedRDD.flatMap(x => {
val y = x._2
for (w <- y) yield (x._1, w._1, w._2)
})
flatRDD.toDF("category", "name", "score").show(false)
输出:
+--------+-----+-----+ |category|name |score| +--------+-----+-----+ |B |Lee |100 | |B |Jim |100 | |B |Tom |99 | |C |James|89 | |C |Jim |65 | |C |Tom |53 | |A |Lee |93 | |A |Tom |78 | |A |Jim |43 | +--------+-----+-----+ --------------------- |