1.案例:对每个班级内的学生成绩,取出前3名
1>文件部分数据:
2>代码:
Java版:
import java.util.Arrays;
import java.util.Iterator;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import scala.Tuple2;
public class GroupTopN {
    /**
     * Reads "className score" lines from a local text file, groups the scores
     * by class name and prints the top 3 scores of every class in descending
     * order.
     */
    public static void main(String[] args) {
        SparkConf conf = new SparkConf()
                .setAppName("GroupTopN")
                .setMaster("local");
        JavaSparkContext sc = new JavaSparkContext(conf);
        // 读取数据 — each line is expected to look like: "<className> <score>"
        JavaRDD<String> lines = sc.textFile("G://SparkDevel//test//wordCount//data//score.txt");
        // 创建(class,score)元组
        JavaPairRDD<String, Integer> classScore = lines.mapToPair(
            new PairFunction<String, String, Integer>() {
                private static final long serialVersionUID = 1L;
                @Override
                public Tuple2<String, Integer> call(String line) throws Exception {
                    String[] fields = line.split(" ");
                    return new Tuple2<String, Integer>(fields[0], Integer.valueOf(fields[1]));
                }
            });
        // 按key=class来分类
        JavaPairRDD<String, Iterable<Integer>> groupedPairs = classScore.groupByKey();
        // 获取每个class的topN的分数, 本例中N=3
        JavaPairRDD<String, Iterable<Integer>> groupTop3 = groupedPairs.mapToPair(
            new PairFunction<Tuple2<String, Iterable<Integer>>, String, Iterable<Integer>>() {
                private static final long serialVersionUID = 1L;
                @Override
                public Tuple2<String, Iterable<Integer>> call(Tuple2<String, Iterable<Integer>> group)
                        throws Exception {
                    // Fixed-size array kept in descending order; null marks a
                    // still-empty slot, so a real score of 0 is handled correctly.
                    Integer[] top3 = new Integer[3];
                    Iterator<Integer> scores = group._2.iterator();
                    while (scores.hasNext()) {
                        Integer score = scores.next();
                        for (int i = 0; i < top3.length; i++) {
                            if (top3[i] == null) {
                                top3[i] = score;
                                break;
                            } else if (score > top3[i]) {
                                // Shift smaller entries down to make room.
                                for (int j = top3.length - 1; j > i; j--) {
                                    top3[j] = top3[j - 1];
                                }
                                top3[i] = score;
                                break;
                            }
                        }
                    }
                    // Trim trailing empty slots so a class with fewer than 3
                    // students does not report "null" scores (the original
                    // always emitted all 3 slots, nulls included).
                    int filled = 0;
                    while (filled < top3.length && top3[filled] != null) {
                        filled++;
                    }
                    return new Tuple2<String, Iterable<Integer>>(
                            group._1, Arrays.asList(Arrays.copyOfRange(top3, 0, filled)));
                }
            });
        // 打印已经分好类排好序的数据
        groupTop3.foreach(new VoidFunction<Tuple2<String, Iterable<Integer>>>() {
            private static final long serialVersionUID = 1L;
            @Override
            public void call(Tuple2<String, Iterable<Integer>> group) throws Exception {
                System.out.println("class: " + group._1);
                for (Integer score : group._2) {
                    System.out.println(score);
                }
                System.out.println("=======================================");
            }
        });
        sc.close();
    }
}
Scala版:
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import breeze.linalg.split
import scala.util.control.Breaks._
import sun.font.TrueTypeFont
object GroupTopN {
  /**
   * Reads "className score" lines, groups the scores by class and prints the
   * top N (N = 3) scores of every class in descending order.
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("GroupTopN")
      .setMaster("local")
    val sc = new SparkContext(conf)
    // Each input line is expected to look like: "<className> <score>"
    val lines = sc.textFile("G://SparkDevel//test//wordCount//data//score.txt", 4)
    // Split each line once instead of twice per record.
    val classScores = lines.map { line =>
      val fields = line.split(" ")
      (fields(0), fields(1).toInt)
    }
    val grouped = classScores.groupByKey()
    val n = 3
    // Sort each (small, per-class) score collection descending and keep the
    // first n. This replaces the original manual insertion loop, which used 0
    // as an "empty slot" sentinel and therefore mishandled legitimate scores
    // of 0 and padded classes with fewer than n students with bogus zeros.
    val groupTopN = grouped.map { case (className, scores) =>
      (className, scores.toList.sortBy(score => -score).take(n))
    }
    groupTopN.foreach { case (className, topScores) =>
      println("class:" + className)
      for (score <- topScores) {
        println(score)
      }
      println("===============")
    }
    // Release Spark resources (the original never stopped the context; the
    // Java version of this example calls sc.close()).
    sc.stop()
  }
}
结果: