Spark分组取TopN

这篇文章主要介绍在Spark中如何分组取TopN元素的两种方法:

  1. 第一种方法基于Spark SQL的窗口函数实现,
  2. 第二种方法基于原生的RDD接口实现。

构造数据

首先我们构造一份班级的成绩数据,这份数据有三列组成,第一列是考试科目category,第二列是学生的名字name,第三列是学生的成绩。如下:

val df = spark.createDataFrame(Seq(
  ("A", "Tom", 78),
  ("B", "James", 47),
  ("A", "Jim", 43),
  ("C", "James", 89),
  ("A", "Lee", 93),
  ("C", "Jim", 65),
  ("A", "James", 10),
  ("C", "Lee", 39),
  ("B", "Tom", 99),
  ("C", "Tom", 53),
  ("B", "Lee", 100),
  ("B", "Jim", 100)
)).toDF("category", "name", "score")
df.show(false)

输出:

+--------+-----+-----+
|category|name |score|
+--------+-----+-----+
|A       |Tom  |78   |
|B       |James|47   |
|A       |Jim  |43   |
|C       |James|89   |
|A       |Lee  |93   |
|C       |Jim  |65   |
|A       |James|10   |
|C       |Lee  |39   |
|B       |Tom  |99   |
|C       |Tom  |53   |
|B       |Lee  |100  |
|B       |Jim  |100  |
+--------+-----+-----+


我们要实现的目标是:取出每个科目下成绩排名前三的学生

 

1、使用窗口函数实现

Spark SQL从1.4开始支持窗口分析函数,我们可以使用窗口函数row_number来进行分组排序,然后在对每个分区取出TopN个元素。row_number函数作用于一个分区,并为该分区中的每条记录生成一个从1开始递增的序列号,这样在外层循环就可以通过过滤该序列号来获取特定的数据。

①使用窗口函数取TopN

val N = 3
val window = Window.partitionBy(col("category")).orderBy(col("score").desc)
val top3DF = df.withColumn("topn", row_number().over(window)).where(col("topn") <= N)
top3DF.show(false)

输出:

+--------+-----+-----+----+
|category|name |score|topn|
+--------+-----+-----+----+
|B       |Lee  |100  |1   |
|B       |Jim  |100  |2   |
|B       |Tom  |99   |3   |
|C       |James|89   |1   |
|C       |Jim  |65   |2   |
|C       |Tom  |53   |3   |
|A       |Lee  |93   |1   |
|A       |Tom  |78   |2   |
|A       |Jim  |43   |3   |
+--------+-----+-----+----+

②也可以使用sql查询的方式

df.createOrReplaceTempView("grade")
val sql = "select category, name, score from (select category, name, score, row_number() over (partition by category order by score desc ) rank from grade) g where g.rank <= 3".stripMargin
val top3DFBySQL = spark.sql(sql)
top3DFBySQL.show(false)

2、使用元素RDD接口实现

使用原生RDD接口来获取TopN元素主要需要以下三个步骤:

  1. 将数据按指定的标准分组。比如在本例中需要按“category”分组
  2. 对每个分组中的元素进行排序,然后取TopN个元素
  3. 将以上数据Flat展开,恢复为原有格式
// 使用RDD取Top
// step 1: 分组
val groupRDD = df.rdd.map(x => (x.getString(0), (x.getString(1), x.getInt(2)))).groupByKey()

// step 2: 排序并取TopN
val N = 3
val sortedRDD = groupRDD.map(x => {
  val rawRows = x._2.toBuffer
  val sortedRows = rawRows.sortBy(_._2.asInstanceOf[Integer])
  // 取TopN元素
  if (sortedRows.size > N) {
    sortedRows.remove(0, (sortedRows.length - N))
  }
  (x._1, sortedRows.toIterator)
})

// step 3: 展开
val flatRDD = sortedRDD.flatMap(x => {
  val y = x._2
  for (w <- y) yield (x._1, w._1, w._2)
})

flatRDD.toDF("category", "name", "score").show(false)

输出:

+--------+-----+-----+
|category|name |score|
+--------+-----+-----+
|B       |Lee  |100  |
|B       |Jim  |100  |
|B       |Tom  |99   |
|C       |James|89   |
|C       |Jim  |65   |
|C       |Tom  |53   |
|A       |Lee  |93   |
|A       |Tom  |78   |
|A       |Jim  |43   |
+--------+-----+-----+
--------------------- 


转载:https://blog.csdn.net/Xiejingfa/article/details/79831938
 

展开阅读全文

没有更多推荐了,返回首页