Do You Know the Ways to Implement TopN in Spark?

In real-world development we frequently run into TopN requirements. So how do you compute a TopN in Spark? With that question in mind, let's walk through the available implementations.

1. Using groupByKey

1.1 Requirements

A category is a product classification. Large e-commerce sites use multi-level categories; in this project there is only one level. Different companies define "hot" differently; here we rank hot categories by the number of clicks, orders, and payments per category:

shoes: clicks, orders, payments
clothes: clicks, orders, payments
computers: clicks, orders, payments

For example, a composite score could be computed as clicks * 20% + orders * 30% + payments * 50%. For broader applicability, this case study sorts by click count; ties are broken by order count, and remaining ties by payment count.
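Both sortBy and top in the code below rely on CategoryCountInfo having a natural ordering that encodes this rule. The original class is not shown in the post; the following is a hypothetical sketch, with field names inferred from the getters used later:

```java
// Hypothetical sketch of the CategoryCountInfo bean used throughout this post.
// compareTo implements the ranking rule: clicks first, then orders, then payments.
class CategoryCountInfo implements Comparable<CategoryCountInfo> {
    private final String categoryId;
    private long clickCount;
    private long orderCount;
    private long payCount;

    CategoryCountInfo(String categoryId, long clickCount, long orderCount, long payCount) {
        this.categoryId = categoryId;
        this.clickCount = clickCount;
        this.orderCount = orderCount;
        this.payCount = payCount;
    }

    @Override
    public int compareTo(CategoryCountInfo o) {
        if (clickCount != o.clickCount) return Long.compare(clickCount, o.clickCount);
        if (orderCount != o.orderCount) return Long.compare(orderCount, o.orderCount);
        return Long.compare(payCount, o.payCount);
    }

    String getCategoryId() { return categoryId; }
    long getClickCount() { return clickCount; }
    long getOrderCount() { return orderCount; }
    long getPayCount() { return payCount; }
    void setClickCount(long v) { clickCount = v; }
    void setOrderCount(long v) { orderCount = v; }
    void setPayCount(long v) { payCount = v; }
}
```

The real project class likely also carries Lombok annotations and Serializable; only the Comparable contract matters for the TopN logic.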

1.2 Approach

1. Group the data by key (category ID).

JavaPairRDD<String, Iterable<CategoryCountInfo>> groupByRDD = categoryCountInfoJavaRDD
        .groupBy(new Function<CategoryCountInfo, String>() {
            @Override
            public String call(CategoryCountInfo v1) throws Exception {
                return v1.getCategoryId();
            }
        });

2. Use the map operator to aggregate the records of each category from the grouped iterator.

JavaRDD<CategoryCountInfo> countInfoJavaRDD = groupByRDD
        .map(new Function<Tuple2<String, Iterable<CategoryCountInfo>>, CategoryCountInfo>() {
            @Override
            public CategoryCountInfo call(Tuple2<String, Iterable<CategoryCountInfo>> v1) throws Exception {
                CategoryCountInfo result = new CategoryCountInfo(v1._1, 0L, 0L, 0L);
                for (CategoryCountInfo countInfo : v1._2) {
                    result.setClickCount(result.getClickCount() + countInfo.getClickCount());
                    result.setOrderCount(result.getOrderCount() + countInfo.getOrderCount());
                    result.setPayCount(result.getPayCount() + countInfo.getPayCount());
                }
                return result;
            }
        });

3. Sort with the sortBy operator and take the TopN. (Note that top(num) already returns the num largest elements by natural ordering, so strictly speaking the preceding sortBy is redundant.)

JavaRDD<CategoryCountInfo> result = countInfoJavaRDD
        .sortBy(new Function<CategoryCountInfo, CategoryCountInfo>() {
            @Override
            public CategoryCountInfo call(CategoryCountInfo v1) throws Exception {
                return v1;
            }
        }, false, 2);
return result.top(num);

1.3 Environment and Data Preparation

Data format

We use user-behavior data from an e-commerce site, covering four user actions: search, click, order, and payment.

(1) Fields are separated by underscores (_).

(2) Each line records a single user action, so each line is exactly one of the four action types.

(3) If the click category ID and product ID are -1, the line is not a click.

(4) For an order, multiple products can be ordered at once, so both the category IDs and product IDs can hold multiple comma-separated values. If the line is not an order, these fields are null.

(5) Payment lines follow the same format as order lines.
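The format rules above can be applied to a single line as sketched below. This is not the post's actual dataToRdd; the real file has many more fields, so this assumes a simplified, hypothetical four-field layout (searchKeyword_clickCategoryId_orderCategoryIds_payCategoryIds) purely for illustration:

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Minimal sketch of applying the format rules to one line.
// Assumed hypothetical layout: searchKeyword_clickCategoryId_orderCategoryIds_payCategoryIds
class ActionLineParser {
    // Rule (3): a click category id of -1 means the line is not a click.
    static boolean isClick(String line) {
        return !"-1".equals(line.split("_")[1]);
    }

    // Rule (4): order category ids are comma-separated, or the literal "null".
    static List<String> orderCategoryIds(String line) {
        String ids = line.split("_")[2];
        return "null".equals(ids) ? Collections.emptyList() : Arrays.asList(ids.split(","));
    }
}
```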

Dependencies

<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-core_2.12</artifactId>
    <version>3.1.3</version>
</dependency>
<dependency>
    <groupId>com.jolbox</groupId>
    <artifactId>bonecp</artifactId>
    <version>0.8.0.RELEASE</version>
</dependency>
<dependency>
    <groupId>org.slf4j</groupId>
    <artifactId>slf4j-log4j12</artifactId>
    <version>1.7.7</version>
</dependency>
<dependency>
    <groupId>org.projectlombok</groupId>
    <artifactId>lombok</artifactId>
    <version>1.18.22</version>
</dependency>

1.4 Complete Code and Results

public static void main(String[] args) throws ClassNotFoundException {
    SparkConf conf = new SparkConf().setMaster("local[*]").setAppName("sparkCore")
            .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
            .registerKryoClasses(new Class[]{
                    Class.forName("com.atguigu.CategoryCountInfo"),
                    Class.forName("com.atguigu.UserVisitAction")
            });
    JavaSparkContext sc = new JavaSparkContext(conf);
    JavaRDD<String> stringJavaRDD = sc.textFile("SparkDemo/input/user_visit_action.txt");
    JavaRDD<CategoryCountInfo> categoryCountInfoJavaRDD = dataToRdd(stringJavaRDD);
    myGroupByKey(categoryCountInfoJavaRDD, 5).forEach(System.out::println);
    myTowStageWithGroup(categoryCountInfoJavaRDD, 5).forEach(System.out::println);
    myPartitionKeyBy(categoryCountInfoJavaRDD, 10, 5).forEach(System.out::println);
    sc.stop();
}

public static List<CategoryCountInfo> myGroupByKey(JavaRDD<CategoryCountInfo> categoryCountInfoJavaRDD, int num) {
    JavaPairRDD<String, Iterable<CategoryCountInfo>> groupByRDD = categoryCountInfoJavaRDD
            .groupBy(new Function<CategoryCountInfo, String>() {
                @Override
                public String call(CategoryCountInfo v1) throws Exception {
                    return v1.getCategoryId();
                }
            });
    JavaRDD<CategoryCountInfo> countInfoJavaRDD = groupByRDD
            .map(new Function<Tuple2<String, Iterable<CategoryCountInfo>>, CategoryCountInfo>() {
                @Override
                public CategoryCountInfo call(Tuple2<String, Iterable<CategoryCountInfo>> v1) throws Exception {
                    CategoryCountInfo result = new CategoryCountInfo(v1._1, 0L, 0L, 0L);
                    for (CategoryCountInfo countInfo : v1._2) {
                        result.setClickCount(result.getClickCount() + countInfo.getClickCount());
                        result.setOrderCount(result.getOrderCount() + countInfo.getOrderCount());
                        result.setPayCount(result.getPayCount() + countInfo.getPayCount());
                    }
                    return result;
                }
            });
    // CategoryCountInfo must be comparable (implement Comparable)
    JavaRDD<CategoryCountInfo> result = countInfoJavaRDD
            .sortBy(new Function<CategoryCountInfo, CategoryCountInfo>() {
                @Override
                public CategoryCountInfo call(CategoryCountInfo v1) throws Exception {
                    return v1;
                }
            }, false, 2);
    // The method declares a List return type, so take the TopN here
    // (returning the RDD directly, as in the original, does not compile).
    return result.top(num);
}

2. Two-Stage Aggregation

2.1 Approach

Stage one prepends a random prefix to every key, performs a local aggregation, and then strips the prefix from the key.

Random random = new Random();
JavaPairRDD<Tuple2, CategoryCountInfo> app = categoryCountInfoJavaRDD
        .mapToPair(new PairFunction<CategoryCountInfo, Tuple2, CategoryCountInfo>() {
            @Override
            public Tuple2<Tuple2, CategoryCountInfo> call(CategoryCountInfo categoryCountInfo) throws Exception {
                // add a random prefix
                return new Tuple2<>(new Tuple2<>(random.nextInt(10), categoryCountInfo.getCategoryId()), categoryCountInfo);
            }
        });
app.groupByKey()
        .map(new Function<Tuple2<Tuple2, Iterable<CategoryCountInfo>>, Tuple2<String, CategoryCountInfo>>() {
            @Override
            public Tuple2<String, CategoryCountInfo> call(Tuple2<Tuple2, Iterable<CategoryCountInfo>> tuple2IterableTuple2) throws Exception {
                // local aggregation
                CategoryCountInfo result = new CategoryCountInfo(tuple2IterableTuple2._1._2.toString(), 0L, 0L, 0L);
                for (CategoryCountInfo categoryCountInfo : tuple2IterableTuple2._2) {
                    result.setClickCount(result.getClickCount() + categoryCountInfo.getClickCount());
                    result.setOrderCount(result.getOrderCount() + categoryCountInfo.getOrderCount());
                    result.setPayCount(result.getPayCount() + categoryCountInfo.getPayCount());
                }
                // strip the prefix from the key
                return new Tuple2<>(result.getCategoryId(), result);
            }
        })

Group the resulting RDD again, then perform the global aggregation.

        .groupBy(new Function<Tuple2<String, CategoryCountInfo>, String>() {
            @Override
            public String call(Tuple2<String, CategoryCountInfo> stringCategoryCountInfoTuple2) throws Exception {
                return stringCategoryCountInfoTuple2._1;
            }
        })
        .map(new Function<Tuple2<String, Iterable<Tuple2<String, CategoryCountInfo>>>, CategoryCountInfo>() {
            @Override
            public CategoryCountInfo call(Tuple2<String, Iterable<Tuple2<String, CategoryCountInfo>>> stringIterableTuple2) throws Exception {
                CategoryCountInfo result = new CategoryCountInfo(stringIterableTuple2._1, 0L, 0L, 0L);
                for (Tuple2<String, CategoryCountInfo> categoryCountInfo : stringIterableTuple2._2) {
                    result.setClickCount(result.getClickCount() + categoryCountInfo._2.getClickCount());
                    result.setOrderCount(result.getOrderCount() + categoryCountInfo._2.getOrderCount());
                    result.setPayCount(result.getPayCount() + categoryCountInfo._2.getPayCount());
                }
                return result;
            }
        });

Finally, sort and take the TopN.

List<CategoryCountInfo> top = map
        .sortBy(new Function<CategoryCountInfo, CategoryCountInfo>() {
            @Override
            public CategoryCountInfo call(CategoryCountInfo categoryCountInfo) throws Exception {
                return categoryCountInfo;
            }
        }, false, 2)
        .top(num);
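The salting pipeline above can be simulated without Spark to confirm that the two-stage sum yields the same totals as a direct per-key sum. The class, bucket count, and key strings below are illustrative, not part of the post's code:

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

// Pure-Java simulation of the salting idea: stage 1 aggregates under a
// (randomPrefix, key) composite, stage 2 strips the prefix and aggregates again.
// The totals are identical to a direct per-key count; only the shuffle-skew
// behavior differs in real Spark.
class TwoStageSum {
    static Map<String, Long> aggregate(String[] keys, int saltBuckets, long seed) {
        Random random = new Random(seed);
        // Stage 1: local aggregation under salted keys, e.g. "3#shoes".
        Map<String, Long> salted = new HashMap<>();
        for (String key : keys) {
            salted.merge(random.nextInt(saltBuckets) + "#" + key, 1L, Long::sum);
        }
        // Stage 2: strip the salt and aggregate globally on the real key.
        Map<String, Long> result = new HashMap<>();
        for (Map.Entry<String, Long> e : salted.entrySet()) {
            String realKey = e.getKey().substring(e.getKey().indexOf('#') + 1);
            result.merge(realKey, e.getValue(), Long::sum);
        }
        return result;
    }
}
```

In real Spark the benefit is that a single hot key is spread across up to saltBuckets reducers in stage 1, so no one task has to aggregate all of its records.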

2.2 Code and Results

public static void main(String[] args) throws ClassNotFoundException {
    SparkConf conf = new SparkConf().setMaster("local[*]").setAppName("sparkCore")
            .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
            .registerKryoClasses(new Class[]{
                    Class.forName("com.atguigu.CategoryCountInfo"),
                    Class.forName("com.atguigu.UserVisitAction")
            });
    JavaSparkContext sc = new JavaSparkContext(conf);
    JavaRDD<String> stringJavaRDD = sc.textFile("SparkDemo/input/user_visit_action.txt");
    JavaRDD<CategoryCountInfo> categoryCountInfoJavaRDD = dataToRdd(stringJavaRDD);
    myGroupByKey(categoryCountInfoJavaRDD, 5).forEach(System.out::println);
    myTowStageWithGroup(categoryCountInfoJavaRDD, 5).forEach(System.out::println);
    myPartitionKeyBy(categoryCountInfoJavaRDD, 10, 5).forEach(System.out::println);
    sc.stop();
}

public static List<CategoryCountInfo> myTowStageWithGroup(JavaRDD<CategoryCountInfo> categoryCountInfoJavaRDD, int num) {
    Random random = new Random();
    // Stage 1: salt the key with a random prefix
    JavaPairRDD<Tuple2, CategoryCountInfo> app = categoryCountInfoJavaRDD
            .mapToPair(new PairFunction<CategoryCountInfo, Tuple2, CategoryCountInfo>() {
                @Override
                public Tuple2<Tuple2, CategoryCountInfo> call(CategoryCountInfo categoryCountInfo) throws Exception {
                    return new Tuple2<>(new Tuple2<>(random.nextInt(10), categoryCountInfo.getCategoryId()), categoryCountInfo);
                }
            });
    JavaPairRDD<String, Iterable<Tuple2<String, CategoryCountInfo>>> stringIterableJavaPairRDD = app
            .groupByKey()
            .map(new Function<Tuple2<Tuple2, Iterable<CategoryCountInfo>>, Tuple2<String, CategoryCountInfo>>() {
                @Override
                public Tuple2<String, CategoryCountInfo> call(Tuple2<Tuple2, Iterable<CategoryCountInfo>> tuple2IterableTuple2) throws Exception {
                    // local aggregation under the salted key
                    CategoryCountInfo result = new CategoryCountInfo(tuple2IterableTuple2._1._2.toString(), 0L, 0L, 0L);
                    for (CategoryCountInfo categoryCountInfo : tuple2IterableTuple2._2) {
                        result.setClickCount(result.getClickCount() + categoryCountInfo.getClickCount());
                        result.setOrderCount(result.getOrderCount() + categoryCountInfo.getOrderCount());
                        result.setPayCount(result.getPayCount() + categoryCountInfo.getPayCount());
                    }
                    // strip the prefix
                    return new Tuple2<>(result.getCategoryId(), result);
                }
            })
            .groupBy(new Function<Tuple2<String, CategoryCountInfo>, String>() {
                @Override
                public String call(Tuple2<String, CategoryCountInfo> stringCategoryCountInfoTuple2) throws Exception {
                    return stringCategoryCountInfoTuple2._1;
                }
            });
    // Stage 2: global aggregation on the real key
    JavaRDD<CategoryCountInfo> map = stringIterableJavaPairRDD
            .map(new Function<Tuple2<String, Iterable<Tuple2<String, CategoryCountInfo>>>, CategoryCountInfo>() {
                @Override
                public CategoryCountInfo call(Tuple2<String, Iterable<Tuple2<String, CategoryCountInfo>>> stringIterableTuple2) throws Exception {
                    CategoryCountInfo result = new CategoryCountInfo(stringIterableTuple2._1, 0L, 0L, 0L);
                    for (Tuple2<String, CategoryCountInfo> categoryCountInfo : stringIterableTuple2._2) {
                        result.setClickCount(result.getClickCount() + categoryCountInfo._2.getClickCount());
                        result.setOrderCount(result.getOrderCount() + categoryCountInfo._2.getOrderCount());
                        result.setPayCount(result.getPayCount() + categoryCountInfo._2.getPayCount());
                    }
                    return result;
                }
            });
    List<CategoryCountInfo> top = map
            .sortBy(new Function<CategoryCountInfo, CategoryCountInfo>() {
                @Override
                public CategoryCountInfo call(CategoryCountInfo categoryCountInfo) throws Exception {
                    return categoryCountInfo;
                }
            }, false, 2)
            .top(num);
    return top;
}

3. Per-Partition TopN First, Then Global TopN

3.1 Approach

Within each partition, aggregate per key and take that partition's TopN.

JavaRDD<CategoryCountInfo> map = app
        .mapPartitions(new FlatMapFunction<Iterator<CategoryCountInfo>, CategoryCountInfo>() {
            @Override
            public Iterator<CategoryCountInfo> call(Iterator<CategoryCountInfo> categoryCountInfoIterator) throws Exception {
                TreeMap<String, CategoryCountInfo> stringCategoryCountInfoHashMap = new TreeMap<>();
                while (categoryCountInfoIterator.hasNext()) {
                    CategoryCountInfo next = categoryCountInfoIterator.next();
                    if (!stringCategoryCountInfoHashMap.containsKey(next.getCategoryId())) {
                        stringCategoryCountInfoHashMap.put(next.getCategoryId(), next);
                    } else {
                        CategoryCountInfo categoryCountInfo = stringCategoryCountInfoHashMap.get(next.getCategoryId());
                        categoryCountInfo.setClickCount(next.getClickCount() + categoryCountInfo.getClickCount());
                        categoryCountInfo.setOrderCount(next.getOrderCount() + categoryCountInfo.getOrderCount());
                        categoryCountInfo.setPayCount(next.getPayCount() + categoryCountInfo.getPayCount());
                        stringCategoryCountInfoHashMap.put(next.getCategoryId(), categoryCountInfo);
                    }
                }
                // A local TopN is taken here. If the data volume is small, or you need
                // exact counts for every record in the final TopN, you can skip this limit.
                return stringCategoryCountInfoHashMap.values().stream().limit(first).iterator();
            }
        })

Then perform the global aggregation.

        .groupBy(new Function<CategoryCountInfo, String>() {
            @Override
            public String call(CategoryCountInfo categoryCountInfo) throws Exception {
                return categoryCountInfo.getCategoryId();
            }
        })
        .map(new Function<Tuple2<String, Iterable<CategoryCountInfo>>, CategoryCountInfo>() {
            @Override
            public CategoryCountInfo call(Tuple2<String, Iterable<CategoryCountInfo>> stringIterableTuple2) throws Exception {
                CategoryCountInfo result = new CategoryCountInfo(stringIterableTuple2._1, 0L, 0L, 0L);
                for (CategoryCountInfo categoryCountInfo : stringIterableTuple2._2) {
                    result.setClickCount(result.getClickCount() + categoryCountInfo.getClickCount());
                    result.setOrderCount(result.getOrderCount() + categoryCountInfo.getOrderCount());
                    result.setPayCount(result.getPayCount() + categoryCountInfo.getPayCount());
                }
                return result;
            }
        });

Finally, sort globally and take the TopN.

List<CategoryCountInfo> top = map
        .sortBy(new Function<CategoryCountInfo, CategoryCountInfo>() {
            @Override
            public CategoryCountInfo call(CategoryCountInfo categoryCountInfo) throws Exception {
                return categoryCountInfo;
            }
        }, false, 2)
        .top(second);
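The per-partition-then-global flow above can be sketched in plain Java (names here are illustrative, not the post's code). Note the caveat the in-code comment already hints at: the local limit is lossy, since a key that is small in every partition but large overall can be dropped before the global step:

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

// Pure-Java sketch of method 3: each "partition" pre-aggregates its rows per
// key and keeps only its local top-k; the local winners are then merged and
// the global top-k is taken.
class PartitionTopN {
    static List<String> topN(List<List<Map.Entry<String, Long>>> partitions, int localK, int globalK) {
        List<Map.Entry<String, Long>> locals = new ArrayList<>();
        for (List<Map.Entry<String, Long>> partition : partitions) {
            Map<String, Long> counts = new HashMap<>();
            for (Map.Entry<String, Long> row : partition) {
                counts.merge(row.getKey(), row.getValue(), Long::sum);
            }
            // Keep only this partition's local top-k.
            counts.entrySet().stream()
                  .sorted(Map.Entry.<String, Long>comparingByValue().reversed())
                  .limit(localK)
                  .forEach(locals::add);
        }
        // Global aggregation over the per-partition winners, then the final top-k.
        Map<String, Long> global = new HashMap<>();
        for (Map.Entry<String, Long> e : locals) {
            global.merge(e.getKey(), e.getValue(), Long::sum);
        }
        return global.entrySet().stream()
                     .sorted(Map.Entry.<String, Long>comparingByValue().reversed())
                     .limit(globalK)
                     .map(Map.Entry::getKey)
                     .collect(Collectors.toList());
    }
}
```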

3.2 Code and Results

public static void main(String[] args) throws ClassNotFoundException {
    SparkConf conf = new SparkConf().setMaster("local[*]").setAppName("sparkCore")
            .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
            .registerKryoClasses(new Class[]{
                    Class.forName("com.atguigu.CategoryCountInfo"),
                    Class.forName("com.atguigu.UserVisitAction")
            });
    JavaSparkContext sc = new JavaSparkContext(conf);
    JavaRDD<String> stringJavaRDD = sc.textFile("SparkDemo/input/user_visit_action.txt");
    JavaRDD<CategoryCountInfo> categoryCountInfoJavaRDD = dataToRdd(stringJavaRDD);
    myGroupByKey(categoryCountInfoJavaRDD, 5).forEach(System.out::println);
    myTowStageWithGroup(categoryCountInfoJavaRDD, 5).forEach(System.out::println);
    myPartitionKeyBy(categoryCountInfoJavaRDD, 10, 5).forEach(System.out::println);
    // close the SparkContext
    sc.stop();
}

public static List<CategoryCountInfo> myPartitionKeyBy(JavaRDD<CategoryCountInfo> app, int first, int second) {
    JavaRDD<CategoryCountInfo> map = app
            .mapPartitions(new FlatMapFunction<Iterator<CategoryCountInfo>, CategoryCountInfo>() {
                @Override
                public Iterator<CategoryCountInfo> call(Iterator<CategoryCountInfo> categoryCountInfoIterator) throws Exception {
                    TreeMap<String, CategoryCountInfo> stringCategoryCountInfoHashMap = new TreeMap<>();
                    while (categoryCountInfoIterator.hasNext()) {
                        CategoryCountInfo next = categoryCountInfoIterator.next();
                        if (!stringCategoryCountInfoHashMap.containsKey(next.getCategoryId())) {
                            stringCategoryCountInfoHashMap.put(next.getCategoryId(), next);
                        } else {
                            CategoryCountInfo categoryCountInfo = stringCategoryCountInfoHashMap.get(next.getCategoryId());
                            categoryCountInfo.setClickCount(next.getClickCount() + categoryCountInfo.getClickCount());
                            categoryCountInfo.setOrderCount(next.getOrderCount() + categoryCountInfo.getOrderCount());
                            categoryCountInfo.setPayCount(next.getPayCount() + categoryCountInfo.getPayCount());
                            stringCategoryCountInfoHashMap.put(next.getCategoryId(), categoryCountInfo);
                        }
                    }
                    // local TopN per partition
                    return stringCategoryCountInfoHashMap.values().stream().limit(first).iterator();
                }
            })
            .groupBy(new Function<CategoryCountInfo, String>() {
                @Override
                public String call(CategoryCountInfo categoryCountInfo) throws Exception {
                    return categoryCountInfo.getCategoryId();
                }
            })
            .map(new Function<Tuple2<String, Iterable<CategoryCountInfo>>, CategoryCountInfo>() {
                @Override
                public CategoryCountInfo call(Tuple2<String, Iterable<CategoryCountInfo>> stringIterableTuple2) throws Exception {
                    CategoryCountInfo result = new CategoryCountInfo(stringIterableTuple2._1, 0L, 0L, 0L);
                    for (CategoryCountInfo categoryCountInfo : stringIterableTuple2._2) {
                        result.setClickCount(result.getClickCount() + categoryCountInfo.getClickCount());
                        result.setOrderCount(result.getOrderCount() + categoryCountInfo.getOrderCount());
                        result.setPayCount(result.getPayCount() + categoryCountInfo.getPayCount());
                    }
                    return result;
                }
            });
    List<CategoryCountInfo> top = map
            .sortBy(new Function<CategoryCountInfo, CategoryCountInfo>() {
                @Override
                public CategoryCountInfo call(CategoryCountInfo categoryCountInfo) throws Exception {
                    return categoryCountInfo;
                }
            }, false, 2)
            .top(second);
    return top;
}

Summary

Method 1:

1. groupByKey loads every value for a key into memory for processing, which can cause an OOM when a key has very many values.

2. groupByKey ships all value records to the next RDD, which is inefficient, since the actual aggregation only needs partially combined data.
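The second drawback is why a pairwise merge is preferable: only two values per key are ever combined at a time, so the full value list for a hot key is never materialized. A plain-Java sketch of that idea (in Spark this corresponds to reduceByKey with a merge function; the class and names here are illustrative):

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BinaryOperator;

// Pairwise merge per key, the non-Spark analogue of reduceByKey: the
// accumulator holds one value per key, never the whole list of values,
// which is why reduce-style aggregation avoids the groupByKey OOM risk.
class ReduceByKeyDemo {
    static <K, V> Map<K, V> reduceByKey(Iterable<Map.Entry<K, V>> rows, BinaryOperator<V> merge) {
        Map<K, V> acc = new HashMap<>();
        for (Map.Entry<K, V> row : rows) {
            acc.merge(row.getKey(), row.getValue(), merge);
        }
        return acc;
    }
}
```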

Method 2:

1. Works well for problems caused by aggregation-style shuffle operators (groupByKey, reduceByKey, etc.).

2. Hard to apply to problems caused by non-aggregation operators such as join.

Method 3:

1. Fixes both drawbacks of method 1.

2. Like method 2, it follows the idea of pre-aggregating within partitions before aggregating globally.

There are many ways to implement TopN, but they boil down to two kinds: (1) the obvious, unoptimized approach: group, aggregate, and take the TopN directly; (2) optimized approaches that tune the operators or the logic, splitting the aggregation into multiple steps to avoid OOM and improve performance.
