In real-world development we constantly run into TopN-style requirements. So how is TopN computed in Spark? With that question in mind, let's walk through the ways TopN can be implemented.
1. Using groupByKey
1.1 Requirement
A category is a product classification. Large e-commerce sites use multi-level category trees; in our project there is only one level. Companies may define "hot" differently; here we rank hot categories by the number of clicks, orders, and payments in each category:

shoes: clicks, orders, payments
clothes: clicks, orders, payments
computer: clicks, orders, payments

For example, one could use a weighted score such as: composite score = clicks × 20% + orders × 30% + payments × 50%. For broader applicability, this case instead sorts by click count; ties are broken by order count, and remaining ties by payment count.
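The tie-break ordering just described (clicks first, then orders, then payments, all descending) can be sketched as a Comparator. The Info class below is a simplified stand-in for the project's CategoryCountInfo bean, not its real definition:

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

public class CategoryOrdering {
    // Simplified stand-in for CategoryCountInfo: one id plus the three counters.
    static final class Info {
        final String categoryId;
        final long click, order, pay;
        Info(String id, long c, long o, long p) { categoryId = id; click = c; order = o; pay = p; }
    }

    // Descending "hotness": higher clicks win; ties fall through to orders, then payments.
    static final Comparator<Info> HOT_DESC = Comparator
            .comparingLong((Info i) -> i.click)
            .thenComparingLong(i -> i.order)
            .thenComparingLong(i -> i.pay)
            .reversed();

    static List<String> rank(List<Info> infos) {
        return infos.stream()
                .sorted(HOT_DESC)
                .map(i -> i.categoryId)
                .collect(java.util.stream.Collectors.toList());
    }

    public static void main(String[] args) {
        List<Info> data = Arrays.asList(
                new Info("shoes", 10, 2, 1),
                new Info("clothes", 10, 3, 1),   // same clicks as shoes, more orders
                new Info("computer", 12, 1, 1)); // most clicks
        System.out.println(rank(data)); // [computer, clothes, shoes]
    }
}
```

The same ordering is what CategoryCountInfo's Comparable implementation would have to encode for sortBy and top to work later in this article.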
1.2 Approach
1. Group the data by key (the category).
JavaPairRDD<String, Iterable<CategoryCountInfo>> groupByRDD = categoryCountInfoJavaRDD.groupBy(new Function<CategoryCountInfo, String>() {
@Override
public String call(CategoryCountInfo v1) throws Exception {
return v1.getCategoryId();
}
});
2. Use the map operator to aggregate the records of each category held in the iterator.
JavaRDD<CategoryCountInfo> countInfoJavaRDD = groupByRDD.map(new Function<Tuple2<String, Iterable<CategoryCountInfo>>, CategoryCountInfo>() {
@Override
public CategoryCountInfo call(Tuple2<String, Iterable<CategoryCountInfo>> v1) throws Exception {
CategoryCountInfo result = new CategoryCountInfo(v1._1, 0L, 0L, 0L);
Iterable<CategoryCountInfo> countInfos = v1._2;
for (CategoryCountInfo countInfo : countInfos) {
result.setClickCount(result.getClickCount() + countInfo.getClickCount());
result.setOrderCount(result.getOrderCount() + countInfo.getOrderCount());
result.setPayCount(result.getPayCount() + countInfo.getPayCount());
}
return result;
}
});
3. Sort with the sortBy operator and take the TopN.
JavaRDD<CategoryCountInfo> result = countInfoJavaRDD.sortBy(new Function<CategoryCountInfo, CategoryCountInfo>() {
@Override
public CategoryCountInfo call(CategoryCountInfo v1) throws Exception {
return v1;
}
}, false, 2);
// note: top(num) already returns the num largest elements by natural ordering
return result.top(num);
1.3 Environment and Data Preparation
Data format
We use e-commerce user-behavior data covering four kinds of user actions: search, click, order, and pay.
(1) Fields are separated by "_".
(2) Each line records one user action, so a line can only be one of the four action types.
(3) If the clicked category id and product id are -1, the action was not a click.
(4) For an order action, several products can be ordered at once, so the category ids and product ids are each multi-valued, separated by commas. If the action is not an order, these fields are null.
(5) Pay actions use the same format as order actions.
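A minimal sketch of handling these delimiter rules when splitting a line; the field positions in the sample record are hypothetical, since the real user_visit_action.txt schema is not reproduced here:

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class RecordParsing {
    // Split the comma-separated id list used by order/pay fields.
    // "null" (not an order/pay action) and "-1" (not a click) both mean "no ids".
    static List<String> splitIds(String field) {
        if (field == null || "null".equals(field) || "-1".equals(field)) {
            return Collections.emptyList();
        }
        return Arrays.asList(field.split(","));
    }

    public static void main(String[] args) {
        // Hypothetical 4-field record: the third field is an order's category ids.
        String line = "f1_f2_1,2,3_null";
        String[] fields = line.split("_");
        System.out.println(splitIds(fields[2])); // [1, 2, 3]
        System.out.println(splitIds(fields[3])); // []
    }
}
```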
Dependencies
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.12</artifactId>
<version>3.1.3</version>
</dependency>
<dependency>
<groupId>com.jolbox</groupId>
<artifactId>bonecp</artifactId>
<version>0.8.0.RELEASE</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-log4j12</artifactId>
<version>1.7.7</version>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>1.18.22</version>
</dependency>
1.4 Complete Code and Results
public static void main(String[] args) throws ClassNotFoundException {
SparkConf conf = new SparkConf().setMaster("local[*]").setAppName("sparkCore")
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.registerKryoClasses(new Class[]{
Class.forName("com.atguigu.CategoryCountInfo"),
Class.forName("com.atguigu.UserVisitAction")
});
JavaSparkContext sc = new JavaSparkContext(conf);
JavaRDD<String> stringJavaRDD = sc.textFile("SparkDemo/input/user_visit_action.txt");
JavaRDD<CategoryCountInfo> categoryCountInfoJavaRDD = dataToRdd(stringJavaRDD);
myGroupByKey(categoryCountInfoJavaRDD, 5).forEach(System.out::println);
myTowStageWithGroup(categoryCountInfoJavaRDD, 5).forEach(System.out::println);
//topN(result,5).forEach(System.out::println);
myPartitionKeyBy(categoryCountInfoJavaRDD, 10, 5).forEach(System.out::println);
// result. collect().forEach(System.out::println);
sc.stop();
}
public static List<CategoryCountInfo> myGroupByKey(JavaRDD<CategoryCountInfo> categoryCountInfoJavaRDD, int num) {
JavaPairRDD<String, Iterable<CategoryCountInfo>> groupByRDD = categoryCountInfoJavaRDD.groupBy(new Function<CategoryCountInfo, String>() {
@Override
public String call(CategoryCountInfo v1) throws Exception {
return v1.getCategoryId();
}
});
JavaRDD<CategoryCountInfo> countInfoJavaRDD = groupByRDD.map(new Function<Tuple2<String, Iterable<CategoryCountInfo>>, CategoryCountInfo>() {
@Override
public CategoryCountInfo call(Tuple2<String, Iterable<CategoryCountInfo>> v1) throws Exception {
CategoryCountInfo result = new CategoryCountInfo(v1._1, 0L, 0L, 0L);
Iterable<CategoryCountInfo> countInfos = v1._2;
for (CategoryCountInfo countInfo : countInfos) {
result.setClickCount(result.getClickCount() + countInfo.getClickCount());
result.setOrderCount(result.getOrderCount() + countInfo.getOrderCount());
result.setPayCount(result.getPayCount() + countInfo.getPayCount());
}
return result;
}
});
// CategoryCountInfo must implement Comparable so that it can be ordered
JavaRDD<CategoryCountInfo> result = countInfoJavaRDD.sortBy(new Function<CategoryCountInfo, CategoryCountInfo>() {
@Override
public CategoryCountInfo call(CategoryCountInfo v1) throws Exception {
return v1;
}
}, false, 2);
return result.top(num);
}
2. Two-Stage Aggregation
2.1 Approach
Stage one: prepend a random prefix to every key, run a local aggregation, then strip the prefix from the key.
Random random = new Random();
JavaPairRDD<Tuple2, CategoryCountInfo> app = categoryCountInfoJavaRDD
.mapToPair(new PairFunction<CategoryCountInfo, Tuple2, CategoryCountInfo>() {
@Override
public Tuple2<Tuple2, CategoryCountInfo> call(CategoryCountInfo categoryCountInfo) throws Exception {
// add the random prefix
return new Tuple2<>(new Tuple2<>(random.nextInt(10), categoryCountInfo.getCategoryId()), categoryCountInfo);
}
});
app.groupByKey()
.map(new Function<Tuple2<Tuple2, Iterable<CategoryCountInfo>>, Tuple2<String, CategoryCountInfo>>() {
@Override
public Tuple2<String, CategoryCountInfo> call(Tuple2<Tuple2, Iterable<CategoryCountInfo>> tuple2IterableTuple2) throws Exception {
// local aggregation of the records sharing one salted key
CategoryCountInfo result = new CategoryCountInfo(tuple2IterableTuple2._1._2.toString(), 0L, 0L, 0L);
for (CategoryCountInfo categoryCountInfo : tuple2IterableTuple2._2) {
result.setClickCount(result.getClickCount() + categoryCountInfo.getClickCount());
result.setOrderCount(result.getOrderCount() + categoryCountInfo.getOrderCount());
result.setPayCount(result.getPayCount() + categoryCountInfo.getPayCount());
}
// strip the prefix: keep only the categoryId as the key
return new Tuple2<>(result.getCategoryId(), result);
}
})
Stage two: group the resulting RDD again and run the global aggregation.
.groupBy(new Function<Tuple2<String, CategoryCountInfo>, String>() {
@Override
public String call(Tuple2<String, CategoryCountInfo> stringCategoryCountInfoTuple2) throws Exception {
return stringCategoryCountInfoTuple2._1;
}
})
.map(new Function<Tuple2<String, Iterable<Tuple2<String, CategoryCountInfo>>>, CategoryCountInfo>() {
@Override
public CategoryCountInfo call(Tuple2<String, Iterable<Tuple2<String, CategoryCountInfo>>> stringIterableTuple2) throws Exception {
CategoryCountInfo result = new CategoryCountInfo(stringIterableTuple2._1, 0L, 0L, 0L);
for (Tuple2<String, CategoryCountInfo> categoryCountInfo : stringIterableTuple2._2) {
result.setClickCount(result.getClickCount() + categoryCountInfo._2.getClickCount());
result.setOrderCount(result.getOrderCount() + categoryCountInfo._2.getOrderCount());
result.setPayCount(result.getPayCount() + categoryCountInfo._2.getPayCount());
}
return result;
}
});
Finally, sort and take the topN.
List<CategoryCountInfo> top = map
.sortBy(new Function<CategoryCountInfo, CategoryCountInfo>() {
@Override
public CategoryCountInfo call(CategoryCountInfo categoryCountInfo) throws Exception {
return categoryCountInfo;
}
}, false, 2)
.top(num);
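The salting idea above can be illustrated without Spark. The sketch below simulates both stages with plain HashMaps, collapsing the three counters into a single long per key; the method and variable names are hypothetical:

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

public class TwoStageSketch {
    static Map<String, Long> twoStage(String[] keys, int buckets) {
        Random random = new Random();
        // stage 1: aggregate under a salted key, so one hot key is spread
        // over up to `buckets` groups instead of landing in a single one
        Map<String, Long> salted = new HashMap<>();
        for (String k : keys) {
            salted.merge(random.nextInt(buckets) + "#" + k, 1L, Long::sum);
        }
        // stage 2: strip the prefix and merge the partial sums globally
        Map<String, Long> global = new HashMap<>();
        salted.forEach((sk, v) ->
                global.merge(sk.substring(sk.indexOf('#') + 1), v, Long::sum));
        return global;
    }

    public static void main(String[] args) {
        String[] clicks = {"shoes", "shoes", "clothes", "shoes"};
        Map<String, Long> totals = twoStage(clicks, 10);
        System.out.println(totals.get("shoes"));   // 3
        System.out.println(totals.get("clothes")); // 1
    }
}
```

In the Spark version, stage 1 happens per salted key across partitions, which spreads a skewed key's records over several tasks before the final merge.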
2.2 Code and Results
public static void main(String[] args) throws ClassNotFoundException {
SparkConf conf = new SparkConf().setMaster("local[*]").setAppName("sparkCore")
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.registerKryoClasses(new Class[]{
Class.forName("com.atguigu.CategoryCountInfo"),
Class.forName("com.atguigu.UserVisitAction")
});
JavaSparkContext sc = new JavaSparkContext(conf);
JavaRDD<String> stringJavaRDD = sc.textFile("SparkDemo/input/user_visit_action.txt");
JavaRDD<CategoryCountInfo> categoryCountInfoJavaRDD = dataToRdd(stringJavaRDD);
myGroupByKey(categoryCountInfoJavaRDD, 5).forEach(System.out::println);
myTowStageWithGroup(categoryCountInfoJavaRDD, 5).forEach(System.out::println);
// topN(result,5).forEach(System.out::println);
myPartitionKeyBy(categoryCountInfoJavaRDD, 10, 5).forEach(System.out::println);
// result. collect().forEach(System.out::println);
sc.stop();
}
public static List<CategoryCountInfo> myTowStageWithGroup(JavaRDD<CategoryCountInfo> categoryCountInfoJavaRDD, int num) {
Random random = new Random();
JavaPairRDD<Tuple2, CategoryCountInfo> app = categoryCountInfoJavaRDD
.mapToPair(new PairFunction<CategoryCountInfo, Tuple2, CategoryCountInfo>() {
@Override
public Tuple2<Tuple2, CategoryCountInfo> call(CategoryCountInfo categoryCountInfo) throws Exception {
return new Tuple2<>(new Tuple2<>(random.nextInt(10), categoryCountInfo.getCategoryId()), categoryCountInfo);
}
});
JavaPairRDD<String, Iterable<Tuple2<String, CategoryCountInfo>>> stringIterableJavaPairRDD = app
.groupByKey()
.map(new Function<Tuple2<Tuple2, Iterable<CategoryCountInfo>>, Tuple2<String, CategoryCountInfo>>() {
@Override
public Tuple2<String, CategoryCountInfo> call(Tuple2<Tuple2, Iterable<CategoryCountInfo>> tuple2IterableTuple2) throws Exception {
CategoryCountInfo result = new CategoryCountInfo(tuple2IterableTuple2._1._2.toString(), 0L, 0L, 0L);
for (CategoryCountInfo categoryCountInfo : tuple2IterableTuple2._2) {
result.setClickCount(result.getClickCount() + categoryCountInfo.getClickCount());
result.setOrderCount(result.getOrderCount() + categoryCountInfo.getOrderCount());
result.setPayCount(result.getPayCount() + categoryCountInfo.getPayCount());
}
return new Tuple2<>(result.getCategoryId(), result);
}
})
.groupBy(new Function<Tuple2<String, CategoryCountInfo>, String>() {
@Override
public String call(Tuple2<String, CategoryCountInfo> stringCategoryCountInfoTuple2) throws Exception {
return stringCategoryCountInfoTuple2._1;
}
});
JavaRDD<CategoryCountInfo> map = stringIterableJavaPairRDD
.map(new Function<Tuple2<String, Iterable<Tuple2<String, CategoryCountInfo>>>, CategoryCountInfo>() {
@Override
public CategoryCountInfo call(Tuple2<String, Iterable<Tuple2<String, CategoryCountInfo>>> stringIterableTuple2) throws Exception {
CategoryCountInfo result = new CategoryCountInfo(stringIterableTuple2._1, 0L, 0L, 0L);
for (Tuple2<String, CategoryCountInfo> categoryCountInfo : stringIterableTuple2._2) {
result.setClickCount(result.getClickCount() + categoryCountInfo._2.getClickCount());
result.setOrderCount(result.getOrderCount() + categoryCountInfo._2.getOrderCount());
result.setPayCount(result.getPayCount() + categoryCountInfo._2.getPayCount());
}
return result;
}
});
List<CategoryCountInfo> top = map
.sortBy(new Function<CategoryCountInfo, CategoryCountInfo>() {
@Override
public CategoryCountInfo call(CategoryCountInfo categoryCountInfo) throws Exception {
return categoryCountInfo;
}
}, false, 2)
.top(num);
return top;
}
3. Per-Partition TopN First, Then Global TopN
3.1 Approach
Inside each partition, aggregate per key and keep a local TopN.
JavaRDD<CategoryCountInfo> map = app
.mapPartitions(new FlatMapFunction<Iterator<CategoryCountInfo>, CategoryCountInfo>() {
@Override
public Iterator<CategoryCountInfo> call(Iterator<CategoryCountInfo> categoryCountInfoIterator) throws Exception {
TreeMap<String, CategoryCountInfo> stringCategoryCountInfoHashMap = new TreeMap<>();
while (categoryCountInfoIterator.hasNext()) {
CategoryCountInfo next = categoryCountInfoIterator.next();
if (!stringCategoryCountInfoHashMap.containsKey(next.getCategoryId())) {
stringCategoryCountInfoHashMap.put(next.getCategoryId(), next);
} else {
CategoryCountInfo categoryCountInfo = stringCategoryCountInfoHashMap.get(next.getCategoryId());
categoryCountInfo.setClickCount(next.getClickCount() + categoryCountInfo.getClickCount());
categoryCountInfo.setOrderCount(next.getOrderCount() + categoryCountInfo.getOrderCount());
categoryCountInfo.setPayCount(next.getPayCount() + categoryCountInfo.getPayCount());
stringCategoryCountInfoHashMap.put(next.getCategoryId(), categoryCountInfo);
}
}
// The return value applies a local topN; if the data volume is small, or exact
// values are required for every topN record, the limit can be dropped.
return stringCategoryCountInfoHashMap.values().stream().limit(first).iterator();
}
})
Then run the global aggregation.
.groupBy(new Function<CategoryCountInfo, String>() {
@Override
public String call(CategoryCountInfo categoryCountInfo) throws Exception {
return categoryCountInfo.getCategoryId();
}
})
.map(new Function<Tuple2<String, Iterable<CategoryCountInfo>>, CategoryCountInfo>() {
@Override
public CategoryCountInfo call(Tuple2<String, Iterable<CategoryCountInfo>> stringIterableTuple2) throws Exception {
CategoryCountInfo result = new CategoryCountInfo(stringIterableTuple2._1, 0L, 0L, 0L);
for (CategoryCountInfo categoryCountInfo : stringIterableTuple2._2) {
result.setClickCount(result.getClickCount() + categoryCountInfo.getClickCount());
result.setOrderCount(result.getOrderCount() + categoryCountInfo.getOrderCount());
result.setPayCount(result.getPayCount() + categoryCountInfo.getPayCount());
}
return result;
}
});
Finally, sort globally and take the topN.
List<CategoryCountInfo> top = map
.sortBy(new Function<CategoryCountInfo, CategoryCountInfo>() {
@Override
public CategoryCountInfo call(CategoryCountInfo categoryCountInfo) throws Exception {
return categoryCountInfo;
}
}, false, 2)
.top(second);
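One caveat about the local step above: the TreeMap is keyed by categoryId, so limit(first) keeps the first N categories by id, not the N largest by count. A true in-partition top-N is usually kept with a bounded min-heap. A simplified sketch, with the counts reduced to a single long and all names hypothetical:

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

public class LocalTopN {
    // Keep at most n elements in a min-heap; once full, a new element
    // replaces the current minimum only if it is larger.
    static List<Long> topN(List<Long> counts, int n) {
        PriorityQueue<Long> heap = new PriorityQueue<>(Comparator.naturalOrder());
        for (long c : counts) {
            if (heap.size() < n) {
                heap.offer(c);
            } else if (heap.peek() < c) {
                heap.poll();
                heap.offer(c);
            }
        }
        List<Long> result = new ArrayList<>(heap);
        result.sort(Comparator.reverseOrder()); // largest first
        return result;
    }

    public static void main(String[] args) {
        System.out.println(topN(java.util.Arrays.asList(5L, 1L, 9L, 3L, 7L), 3)); // [9, 7, 5]
    }
}
```

The heap does O(log n) work per record regardless of partition size, which is why it is the usual structure for per-partition top-N inside mapPartitions.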
3.2 Code and Results
public static void main(String[] args) throws ClassNotFoundException {
SparkConf conf = new SparkConf().setMaster("local[*]").setAppName("sparkCore")
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.registerKryoClasses(new Class[]{
Class.forName("com.atguigu.CategoryCountInfo"),
Class.forName("com.atguigu.UserVisitAction")
});
JavaSparkContext sc = new JavaSparkContext(conf);
JavaRDD<String> stringJavaRDD = sc.textFile("SparkDemo/input/user_visit_action.txt");
JavaRDD<CategoryCountInfo> categoryCountInfoJavaRDD = dataToRdd(stringJavaRDD);
myGroupByKey(categoryCountInfoJavaRDD, 5).forEach(System.out::println);
myTowStageWithGroup(categoryCountInfoJavaRDD, 5).forEach(System.out::println);
// topN(result,5).forEach(System.out::println);
myPartitionKeyBy(categoryCountInfoJavaRDD, 10, 5).forEach(System.out::println);
// result. collect().forEach(System.out::println);
// 4. stop the SparkContext
sc.stop();
}
public static List<CategoryCountInfo> myPartitionKeyBy(JavaRDD<CategoryCountInfo> app, int first, int second) {
JavaRDD<CategoryCountInfo> map = app
.mapPartitions(new FlatMapFunction<Iterator<CategoryCountInfo>, CategoryCountInfo>() {
@Override
public Iterator<CategoryCountInfo> call(Iterator<CategoryCountInfo> categoryCountInfoIterator) throws Exception {
TreeMap<String, CategoryCountInfo> stringCategoryCountInfoHashMap = new TreeMap<>();
while (categoryCountInfoIterator.hasNext()) {
CategoryCountInfo next = categoryCountInfoIterator.next();
if (!stringCategoryCountInfoHashMap.containsKey(next.getCategoryId())) {
stringCategoryCountInfoHashMap.put(next.getCategoryId(), next);
} else {
CategoryCountInfo categoryCountInfo = stringCategoryCountInfoHashMap.get(next.getCategoryId());
categoryCountInfo.setClickCount(next.getClickCount() + categoryCountInfo.getClickCount());
categoryCountInfo.setOrderCount(next.getOrderCount() + categoryCountInfo.getOrderCount());
categoryCountInfo.setPayCount(next.getPayCount() + categoryCountInfo.getPayCount());
stringCategoryCountInfoHashMap.put(next.getCategoryId(), categoryCountInfo);
}
}
return stringCategoryCountInfoHashMap.values().stream().limit(first).iterator();
}
})
.groupBy(new Function<CategoryCountInfo, String>() {
@Override
public String call(CategoryCountInfo categoryCountInfo) throws Exception {
return categoryCountInfo.getCategoryId();
}
})
.map(new Function<Tuple2<String, Iterable<CategoryCountInfo>>, CategoryCountInfo>() {
@Override
public CategoryCountInfo call(Tuple2<String, Iterable<CategoryCountInfo>> stringIterableTuple2) throws Exception {
CategoryCountInfo result = new CategoryCountInfo(stringIterableTuple2._1, 0L, 0L, 0L);
for (CategoryCountInfo categoryCountInfo : stringIterableTuple2._2) {
result.setClickCount(result.getClickCount() + categoryCountInfo.getClickCount());
result.setOrderCount(result.getOrderCount() + categoryCountInfo.getOrderCount());
result.setPayCount(result.getPayCount() + categoryCountInfo.getPayCount());
}
return result;
}
});
List<CategoryCountInfo> top = map
.sortBy(new Function<CategoryCountInfo, CategoryCountInfo>() {
@Override
public CategoryCountInfo call(CategoryCountInfo categoryCountInfo) throws Exception {
return categoryCountInfo;
}
}, false, 2)
.top(second);
return top;
}
Summary
Approach 1:
1. groupByKey loads every value of a key into memory for processing, which can trigger an OOM when one key has very many values.
2. groupByKey ships all value records across the shuffle to the next RDD, which is slow, since the aggregation actually needs only the combined partial results.
Approach 2:
1. It works well for problems caused by aggregation-style shuffle operations (groupByKey, reduceByKey, etc.).
2. It is hard to apply to problems caused by non-aggregation operations such as join.
Approach 3:
1. It fixes both drawbacks of Approach 1.
2. Both optimizations follow the same idea: pre-aggregate within partitions first, then aggregate globally.
There are many ways to implement topN, but they boil down to two kinds: 1. the most obvious, unoptimized one, which groups, aggregates, and takes the topN directly; 2. optimized versions, which use better operators or restructure the logic, splitting the aggregation into multiple steps to avoid OOM and improve performance.
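In practice, the groupByKey drawbacks listed under Approach 1 are often avoided with reduceByKey or aggregateByKey, which merge values map-side before the shuffle. The plain-Java sketch below shows that combiner-style merge; the Counts class is a simplified stand-in for CategoryCountInfo:

```java
import java.util.HashMap;
import java.util.Map;

public class CombinerSketch {
    static final class Counts {
        long click, order, pay;
        Counts(long c, long o, long p) { click = c; order = o; pay = p; }
        // Associative merge, the same shape a reduceByKey function would have.
        Counts merge(Counts other) {
            return new Counts(click + other.click, order + other.order, pay + other.pay);
        }
    }

    // Values are folded into one running record per key as they arrive, so only
    // one record per key would cross the shuffle -- this is what reduceByKey
    // does map-side, unlike groupByKey, which ships every raw record.
    static Map<String, Counts> aggregate(String[] categories) {
        Map<String, Counts> acc = new HashMap<>();
        for (String c : categories) {
            // each record in this toy input represents one click event
            acc.merge(c, new Counts(1, 0, 0), Counts::merge);
        }
        return acc;
    }

    public static void main(String[] args) {
        Map<String, Counts> m = aggregate(new String[]{"shoes", "shoes", "clothes"});
        System.out.println(m.get("shoes").click); // 2
    }
}
```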