Spark:求出分组内的TopN

制作测试数据源:

c1 85
c2 77
c3 88
c1 22
c1 66
c3 95
c3 54
c2 91
c2 66
c1 54
c1 65
c2 41
c4 65

spark scala实现代码:

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object GroupTopN1 {
  System.setProperty("hadoop.home.dir", "D:\\Java_Study\\hadoop-common-2.2.0-bin-master")

  case class Rating(userId: String, rating: Long)

  def main(args: Array[String]) {
    val sparkConf = new SparkConf().setAppName("ALS with ML Pipeline")
    val spark = SparkSession
      .builder()
      .config(sparkConf)
      .master("local")
      .config("spark.sql.warehouse.dir", "/")
      .getOrCreate()

    import spark.implicits._
    import spark.sql

    val lines = spark.read.textFile("C:\\Users\\Administrator\\Desktop\\group.txt")
    val classScores = lines.map(line => Rating(line.split(" ")(0).toString, line.split(" ")(1).toLong))

    classScores.createOrReplaceTempView("tb_test")

    var df = sql(
      s"""|select
          | userId,
          | rating,
          | row_number()over(partition by userId order by rating desc) rn
          |from tb_test
          |having(rn<=3)
          |""".stripMargin)
    df.show()

    spark.stop()
  }
}

打印结果:

+------+------+---+
|userId|rating| rn|
+------+------+---+
|    c1|    85|  1|
|    c1|    66|  2|
|    c1|    65|  3|
|    c4|    65|  1|
|    c3|    95|  1|
|    c3|    88|  2|
|    c3|    54|  3|
|    c2|    91|  1|
|    c2|    77|  2|
|    c2|    66|  3|
+------+------+---+

spark java代码实现:

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Function1;

import javax.management.RuntimeErrorException;
import java.util.List;
import java.util.ArrayList;

public class Test {
    public static void main(String[] args) {
        System.out.println("Hello");
        SparkConf sparkConf = new SparkConf().setAppName("ALS with ML Pipeline");
        SparkSession spark = SparkSession
                .builder()
                .config(sparkConf)
                .master("local")
                .config("spark.sql.warehouse.dir", "/")
                .getOrCreate();


        // Create an RDD
        JavaRDD<String> peopleRDD = spark.sparkContext()
                .textFile("C:\\Users\\Administrator\\Desktop\\group.txt", 1)
                .toJavaRDD();

        // The schema is encoded in a string
        String schemaString = "userId rating";

        // Generate the schema based on the string of schema
        List<StructField> fields = new ArrayList<>();
        StructField field1 = DataTypes.createStructField("userId", DataTypes.StringType, true);
        StructField field2 = DataTypes.createStructField("rating", DataTypes.LongType, true);
        fields.add(field1);
        fields.add(field2);
        StructType schema = DataTypes.createStructType(fields);

        // Convert records of the RDD (people) to Rows
        JavaRDD<Row> rowRDD = peopleRDD.map((Function<String, Row>) record -> {
            String[] attributes = record.split(" ");
            if(attributes.length!=2){
                throw new Exception();
            }
            return RowFactory.create(attributes[0],Long.valueOf( attributes[1].trim()));
        });

        // Apply the schema to the RDD
        Dataset<Row> peopleDataFrame = spark.createDataFrame(rowRDD, schema);

        peopleDataFrame.createOrReplaceTempView("tb_test");

        Dataset<Row> items = spark.sql("select userId,rating,row_number()over(partition by userId order by rating desc) rn " +
                "from tb_test " +
                "having(rn<=3)");
        items.show();

        spark.stop();
    }
}

输出结果同上边输出结果。

Java 中使用combineByKey实现TopN:

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;

import java.util.*;

public class SparkJava {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().master("local[*]").appName("Spark").getOrCreate();
        final JavaSparkContext ctx = JavaSparkContext.fromSparkContext(spark.sparkContext());

        List<String> data = Arrays.asList("a,110,a1", "b,122,b1", "c,123,c1", "a,210,a2", "b,212,b2", "a,310,a3", "b,312,b3", "a,410,a4", "b,412,b4");
        JavaRDD<String> javaRDD = ctx.parallelize(data);

        JavaPairRDD<String, Integer> javaPairRDD = javaRDD.mapToPair(new PairFunction<String, String, Integer>() {
            public Tuple2<String, Integer> call(String key) throws Exception {
                return new Tuple2<String, Integer>(key.split(",")[0], Integer.valueOf(key.split(",")[1]));
            }
        });

        final int topN = 3;
        JavaPairRDD<String, List<Integer>> combineByKeyRDD2 = javaPairRDD.combineByKey(new Function<Integer, List<Integer>>() {
            public List<Integer> call(Integer v1) throws Exception {
                List<Integer> items = new ArrayList<Integer>();
                items.add(v1);
                return items;
            }
        }, new Function2<List<Integer>, Integer, List<Integer>>() {
            public List<Integer> call(List<Integer> v1, Integer v2) throws Exception {
                if (v1.size() > topN) {
                    Integer item = Collections.min(v1);
                    v1.remove(item);
                    v1.add(v2);
                }
                return v1;
            }
        }, new Function2<List<Integer>, List<Integer>, List<Integer>>() {
            public List<Integer> call(List<Integer> v1, List<Integer> v2) throws Exception {
                v1.addAll(v2);
                while (v1.size() > topN) {
                    Integer item = Collections.min(v1);
                    v1.remove(item);
                }

                return v1;
            }
        });

        // 由K:String,V:List<Integer> 转化为 K:String,V:Integer
        // old:[(a,[210, 310, 410]), (b,[122, 212, 312]), (c,[123])]
        // new:[(a,210), (a,310), (a,410), (b,122), (b,212), (b,312), (c,123)]
        JavaRDD<Tuple2<String, Integer>> javaTupleRDD = combineByKeyRDD2.flatMap(new FlatMapFunction<Tuple2<String, List<Integer>>, Tuple2<String, Integer>>() {
            public Iterator<Tuple2<String, Integer>> call(Tuple2<String, List<Integer>> stringListTuple2) throws Exception {
                List<Tuple2<String, Integer>> items=new ArrayList<Tuple2<String, Integer>>();
                for(Integer v:stringListTuple2._2){
                    items.add(new Tuple2<String, Integer>(stringListTuple2._1,v));
                }
                return items.iterator();
            }
        });

        JavaRDD<Row> rowRDD = javaTupleRDD.map(new Function<Tuple2<String, Integer>, Row>() {
            public Row call(Tuple2<String, Integer> kv) throws Exception {
                String key = kv._1;
                Integer num = kv._2;

                return RowFactory.create(key, num);
            }
        });

        ArrayList<StructField> fields = new ArrayList<StructField>();
        StructField field = null;
        field = DataTypes.createStructField("key", DataTypes.StringType, true);
        fields.add(field);
        field = DataTypes.createStructField("TopN_values", DataTypes.IntegerType, true);
        fields.add(field);

        StructType schema = DataTypes.createStructType(fields);

        Dataset<Row> df = spark.createDataFrame(rowRDD, schema);
        df.printSchema();
        df.show();

        spark.stop();
    }
}

输出:

root
 |-- key: string (nullable = true)
 |-- TopN_values: integer (nullable = true)

+---+-----------+
|key|TopN_values|
+---+-----------+
|  a|        210|
|  a|        310|
|  a|        410|
|  b|        122|
|  b|        212|
|  b|        312|
|  c|        123|
+---+-----------+

Spark使用combineByKeyWithClassTag函数实现TopN

combineByKeyWithClassTag函数,借助HashSet的排序,此例是取组内最大的N个元素一下是代码:

  • createcombiner就简单的将首个元素装进HashSet然后返回就可以了;
  • mergevalue插入元素之后,如果元素的个数大于N就删除最小的元素;
  • mergeCombiner在合并之后,如果总的个数大于N,就从一次删除最小的元素,知道Hashset内只有N 个元素。
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

import scala.collection.mutable

object Main {
  val N = 3

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName("Spark")
      .getOrCreate()
    val sc = spark.sparkContext
    var SampleDataset = List(
      ("apple.com", 3L),
      ("apple.com", 4L),
      ("apple.com", 1L),
      ("apple.com", 9L),
      ("google.com", 4L),
      ("google.com", 1L),
      ("google.com", 2L),
      ("google.com", 3L),
      ("google.com", 11L),
      ("google.com", 32L),
      ("slashdot.org", 11L),
      ("slashdot.org", 12L),
      ("slashdot.org", 13L),
      ("slashdot.org", 14L),
      ("slashdot.org", 15L),
      ("slashdot.org", 16L),
      ("slashdot.org", 17L),
      ("slashdot.org", 18L),
      ("microsoft.com", 5L),
      ("microsoft.com", 2L),
      ("microsoft.com", 6L),
      ("microsoft.com", 9L),
      ("google.com", 4L))
    val urdd: RDD[(String, Long)] = sc.parallelize(SampleDataset).map((t) => (t._1, t._2))
    var topNs = urdd.combineByKeyWithClassTag(
      //createCombiner
      (firstInt: Long) => {
        var uset = new mutable.TreeSet[Long]()
        uset += firstInt
      },
      // mergeValue
      (uset: mutable.TreeSet[Long], value: Long) => {
        uset += value
        while (uset.size > N) {
          uset.remove(uset.min)
        }
        uset
      },
      //mergeCombiners
      (uset1: mutable.TreeSet[Long], uset2: mutable.TreeSet[Long]) => {
        var resultSet = uset1 ++ uset2
        while (resultSet.size > N) {
          resultSet.remove(resultSet.min)
        }
        resultSet
      }
    )
    import spark.implicits._
    topNs.flatMap(rdd => {
      var uset = new mutable.HashSet[String]()
      for (i <- rdd._2.toList) {
        uset += rdd._1 + "/" + i.toString
      }
      uset
    }).map(rdd => {
      (rdd.split("/")(0), rdd.split("/")(1))
    }).toDF("key", "TopN_values").show()
  }
}

参考《https://blog.csdn.net/gpwner/article/details/78455234》

输出结果:

+-------------+-----------+
|          key|TopN_values|
+-------------+-----------+
|   google.com|          4|
|   google.com|         11|
|   google.com|         32|
|microsoft.com|          9|
|microsoft.com|          6|
|microsoft.com|          5|
|    apple.com|          4|
|    apple.com|          9|
|    apple.com|          3|
| slashdot.org|         16|
| slashdot.org|         17|
| slashdot.org|         18|
+-------------+-----------+

 

### 回答1: Spark RDD中分组TopN案例是指在一个RDD中,根据某个键值进行分组,然后对每个组内的数据进行排序,取出每个组内的前N个数据。这种操作在数据分析和处理中非常常见,可以用于统计每个地区的销售额排名前N的产品、每个用户的消费排名前N的商品等。 优化方面,可以考虑使用Spark SQL或DataFrame来实现分组TopN操作,因为它们提供了更高级的API和优化技术,可以更快速地处理大规模数据。另外,可以使用分布式缓存技术将数据缓存到内存中,以加快数据访问速度。还可以使用分区和并行计算等技术来提高计算效率。 ### 回答2: Spark RDD中分组取Top N的案例可以是对一个大数据集中的用户数据进行分组,然后取每个组中消费金额最高的前N个用户。这个案例可以通过以下步骤来实现: 1. 将用户数据载入Spark RDD中,每个数据记录包含用户ID和消费金额。 2. 使用groupBy函数将RDD按照用户ID进行分组,得到一个以用户ID为key,包含相同用户ID的数据记录的value的RDD。 3. 对每个分组的value调用top函数,指定N的值,以获取每个分组中消费金额最高的前N个用户。 4. 可以将每个分组中Top N的用户使用flatMap函数展开为多个记录,并可以添加一个新的字段表示该记录属于哪个分组。 5. 最后,可以使用collect函数将结果转化为数组或者保存到文件或数据库中。 在这个案例中,进行优化的关键是减少数据的传输和处理开销。可以使用缓存或持久化函数对RDD进行优化,以减少重复计算。另外,可以使用并行操作来加速计算,如使用并行的排序算法,或向集群中的多个节点分发计算任务。 对于分组取Top N的优化,还可以考虑使用局部聚合和全局聚合的策略。首先对每个分组内的数据进行局部聚合,例如计算每个分组的前M个最大值。然后,对所有分组的局部聚合结果进行全局聚合,例如计算所有分组的前K个最大值。 另一个优化策略是使用采样技术,例如随机采样或分层采样,以减少需要处理的数据量。 最后,还可以考虑使用Spark的其他高级功能,如Broadcast变量共享数据,使用累加器进行计数或统计等,来进一步提高性能和效率。 ### 回答3: Spark RDD 是 Spark 提供的一种基于内存的分布式数据处理模型,其核心数据结构是弹性分布式数据集(RDD)。 在 Spark RDD 中,分组TopN 是一种常见的需,即对 RDD 中的数据按某个字段进行分组,并取出每个分组中字段值最大的前 N 个数据。 下面以一个示例来说明分组TopN 的用法和优化方法: 假设有一个包含学生信息的 RDD,其中每条数据都包括学生的学科和分数,我们希望对每个学科取出分数最高的前 3 名学生。 ```python # 创建示例数据 data = [ ("语文", 80), ("数学", 90), ("语文", 85), ("数学", 95), ("语文", 75), ("数学", 92), ("英语", 88) ] rdd = sc.parallelize(data) # 分组TopN top3 = rdd.groupByKey().mapValues(lambda x: sorted(x, reverse=True)[:3]) # 输出结果 for subject, scores in top3.collect(): print(subject, scores) # 输出结果: # 数学 [95, 92, 90] # 语文 [85, 80, 75] # 英语 [88] ``` 在上述代码中,我们先使用 `groupByKey()` 对 RDD 进行分组操作,然后使用 `mapValues()` 对每个分组内的数据进行排序并取前 3 个值。 这种方式的优化点在于,通过将分组操作和取 TopN 操作分开,可以减轻数据倾斜的问题。同时,对每个分组进行排序会占用大量计算资源,可以考虑将数据转换为 Pair RDD,并利用 Spark 提供的 `top()` 算子来优化取 TopN 的操作。 ```python # 转换为 Pair RDD pair_rdd = rdd.map(lambda x: (x[0], x[1])) # 分组并取TopN,使用top()算子代替排序操作 top3 = pair_rdd.groupByKey().mapValues(lambda x: sorted(x, reverse=True)).mapValues(lambda x: x[:3]) # 输出结果 for subject, scores in top3.collect(): print(subject, scores) # 输出结果: # 数学 [95, 92, 90] # 语文 [85, 80, 75] # 英语 [88] ``` 通过以上优化,我们可以更好地处理大规模数据集下的分组TopN 的需,提高计算性能和资源利用率。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值