Top N Fully Demystified

This post covers:

  1. Basic Top N algorithm in practice

  2. Grouped Top N algorithm in practice

  3. Inside the RangePartitioner sorting algorithm

1. Basic Top N Algorithm in Practice

     Top N involves sorting; take simply pulls out a few elements without sorting.
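  To make the distinction concrete, here is a minimal sketch (illustrative only; it assumes a SparkContext named sc already exists):

val data = sc.parallelize(Seq(1, 4, 2, 5, 7, 3, 2, 7, 9, 1, 4, 5))

// take: returns the first num elements in partition order, no sorting involved.
val firstFive = data.take(5)                                     // e.g. Array(1, 4, 2, 5, 7)

// Top N: sort first (descending), then take.
val topFive = data.sortBy(x => x, ascending = false).take(5)     // Array(9, 7, 7, 5, 5)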

  Create a new file, basicTopN.txt, with the following content:

1
4
2
5
7
3
2
7
9
1
4
5


  Going by the source code: take returns an array, not an RDD, and collect is likewise an RDD action that returns an array.

/**
 * Return an array that contains all of the elements in this RDD.
 */
def collect(): Array[T] = withScope {
  val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
  Array.concat(results: _*)
}


/**
 * Take the first num elements of the RDD. It works by first scanning one partition, and use the
 * results from that partition to estimate the number of additional partitions needed to satisfy
 * the limit.
 *
 * @note due to complications in the internal implementation, this method will raise
 * an exception if called on an RDD of `Nothing` or `Null`.
 */
def take(num: Int): Array[T] = withScope {
  if (num == 0) {
    new Array[T](0)
  } else {
    val buf = new ArrayBuffer[T]
    val totalParts = this.partitions.length
    var partsScanned = 0
    while (buf.size < num && partsScanned < totalParts) {
      // The number of partitions to try in this iteration. It is ok for this number to be
      // greater than totalParts because we actually cap it at totalParts in runJob.
      var numPartsToTry = 1
      if (partsScanned > 0) {
        // If we didn't find any rows after the previous iteration, quadruple and retry.
        // Otherwise, interpolate the number of partitions we need to try, but overestimate
        // it by 50%. We also cap the estimation in the end.
        if (buf.size == 0) {
          numPartsToTry = partsScanned * 4
        } else {
          // the left side of max is >=1 whenever partsScanned >= 2
          numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1)
          numPartsToTry = Math.min(numPartsToTry, partsScanned * 4)
        }
      }

      val left = num - buf.size
      val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
      val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p)

      res.foreach(buf ++= _.take(num - buf.size))
      partsScanned += numPartsToTry
    }

    buf.toArray
  }
}

   So the code is as follows:

package com.zhouls.spark.cores

import org.apache.spark.{SparkConf, SparkContext}

/**
  * Basic Top N in practice
  * Created by Administrator on 2016/10/9.
  */
object TopNBasic {
  def main(args: Array[String]) {
    val conf = new SparkConf()
    conf.setAppName("Top N Basically!").setMaster("local")
    val sc = new SparkContext(conf)
    val lines = sc.textFile("D://SoftWare//spark-1.5.2-bin-hadoop2.6//basicTopN.txt")
    val pairs = lines.map(line =>(line.toInt,line)) // build key-value pairs so sortByKey can sort by the numeric key
    val sortedPairs = pairs.sortByKey(false) // sort in descending order
    val sortedData = sortedPairs.map(pair => pair._2)  // map is the usual way to transform each record; here we drop the key and keep only the sorted values
    val top5 = sortedData.take(5)  // grab the top 5 elements
    top5.foreach(println)
  }
}
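For reference, the same Top 5 can also be obtained without an explicit sortByKey; this is a minimal alternative sketch, not part of the original program:

// RDD.top(n) keeps a bounded priority queue per partition and merges the partial
// results on the driver, so it avoids shuffling the entire data set.
// takeOrdered(n) is the ascending counterpart.
val top5Alt = lines.map(_.toInt).top(5)
top5Alt.foreach(println)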

 

 

 

Here is a new knowledge point: setLogLevel.

Look at its source code:

/** Control our logLevel. This overrides any user-defined log settings.
 * @param logLevel The desired log level as a string.
 * Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN
 */
def setLogLevel(logLevel: String) {
  val validLevels = Seq("ALL", "DEBUG", "ERROR", "FATAL", "INFO", "OFF", "TRACE", "WARN")
  if (!validLevels.contains(logLevel)) {
    throw new IllegalArgumentException(
      s"Supplied level $logLevel did not match one of: ${validLevels.mkString(",")}")
  }
  Utils.setLogLevel(org.apache.log4j.Level.toLevel(logLevel))
}
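In a driver program this is typically called right after the SparkContext is created; a minimal sketch (the chosen level is just an example):

val sc = new SparkContext(conf)
sc.setLogLevel("WARN")   // any of ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN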

setLogLevel("ALL")

The corresponding console output:

"C:\Program Files\Java\jdk1.8.0_66\bin\java" -Didea.launcher.port=7533 "-Didea.launcher.bin.path=D:\SoftWare\IntelliJ IDEA\IntelliJ IDEA Community Edition 2016.1.4\bin" -Dfile.encoding=UTF-8 -classpath "C:\Program Files\Java\jdk1.8.0_66\jre\lib\charsets.jar;C:\Program 
artitions
d size 1814.0 B, free 976.2 MB)
16/10/09 09:15:38 DEBUG AkkaRpcEnv$$anonfun$actorRef$lzycompute$1$1$$anon$1: [actor] received message AkkaMessage(UpdateBlockInfo(BlockManagerId(driver, localhost, 52833),broadcast_2_piece0,StorageLevel(false, true, false, false, 1),1814,0,0),true) from Actor[akka://sparkDriver/temp/$g]
16/10/09 09:15:38 DEBUG AkkaRpcEnv$$anonfun$actorRef$lzycompute$1$1$$anon$1: Received RPC message: AkkaMessage(UpdateBlockInfo(BlockManagerId(driver, localhost, 52833),broadcast_2_piece0,StorageLevel(false, true, false, false, 1),1814,0,0),true)
16/10/09 09:15:38 INFO BlockManagerInfo: Added broadcast_2_piece0 in memory on localhost:52833 (size: 1814.0 B, free: 976.3 MB)
16/10/09 09:15:38 DEBUG AkkaRpcEnv$$anonfun$actorRef$lzycompute$1$1$$anon$1: [actor] handled message (3.09051 ms) AkkaMessage(UpdateBlockInfo(BlockManagerId(driver, localhost, 52833),broadcast_2_piece0,StorageLevel(false, true, false, false, 1),1814,0,0),true) from Actor[akka://sparkDriver/temp/$g]
16/10/09 09:15:38 DEBUG BlockManagerMaster: Updated info of block broadcast_2_piece0
16/10/09 09:15:38 DEBUG BlockManager: Told master about block broadcast_2_piece0
16/10/09 09:15:38 DEBUG BlockManager: Put block broadcast_2_piece0 locally took 8 ms
16/10/09 09:15:38 DEBUG BlockManager: Putting block broadcast_2_piece0 without replication took 9 ms
16/10/09 09:15:38 INFO SparkContext: Created broadcast 2 from broadcast at DAGScheduler.scala:861
 bytes)
16/10/09 09:15:39 TRACE DAGScheduler: failed: Set()
16/10/09 09:15:39 INFO DAGScheduler: Job 0 finished: take at TopNBasic.scala:20, took 1.022280 s
9
7
7
5
5
16/10/09 09:15:39 INFO SparkContext: Invoking stop() from shutdown hook
age (5.094032 ms) AkkaMessage(StopCoordinator,false) from Actor[akka://sparkDriver/deadLetters]
16/10/09 09:15:39 INFO ShutdownHookManager: Deleting directory C:\Users\Administrator\AppData\Local\Temp\spark-3656d24c-bfdb-4def-b751-8d7fc84150cb

Process finished with exit code 0

 

 

setLogLevel("DEBUG")

The corresponding console output:

"C:\Program Files\Java\jdk1.8.0_66\bin\java" -Didea.launcher.port=7534 "-Didea.launcher.bin.path=D:\SoftWare\IntelliJ IDEA\IntelliJ IDEA Community Edition 2016.1.4\bin" -Dfile.encoding=UTF-8 -classpath "C:\Program Files\Java\jdk1.8.0_66\jre\lib\charsets.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\deploy.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\ext\access-bridge-64.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\ext\cldrdata.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\ext\dnsns.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\ext\jaccess.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\ext\jfxrt.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\ext\localedata.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\ext\nashorn.jar;C:\Program fun$28
16/10/09 09:18:05 DEBUG AkkaRpcEnv$$anonfun$actorRef$lzycompute$1$1$$anon$1: [actor] handled message (2.022709 ms) AkkaMessage(StatusUpdate(1,FINISHED,java.nio.HeapByteBuffer[pos=0 lim=1185 cap=1185]),false) from Actor[akka://sparkDriver/deadLetters]
16/10/09 09:18:05 INFO TaskSetManager: Finished task 0.0 in stage 1.0 (TID 1) in 153 ms on localhost (1/1)
16/10/09 09:18:05 INFO TaskSchedulerImpl: Removed TaskSet 1.0, whose tasks have all completed, from pool
16/10/09 09:18:05 INFO DAGScheduler: ResultStage 1 (take at TopNBasic.scala:20) finished in 0.163 s
16/10/09 09:18:05 DEBUG DAGScheduler: After removal of stage 1, remaining stages = 1
16/10/09 09:18:05 DEBUG DAGScheduler: After removal of stage 0, remaining stages = 0
16/10/09 09:18:05 INFO DAGScheduler: Job 0 finished: take at TopNBasic.scala:20, took 0.985550 s
9
7
7
5
5
16/10/09 09:18:05 INFO SparkContext: Invoking stop() from shutdown hook
16/10/09 09:18:05 INFO SparkUI: Stopped Spark web UI at http://192.168.56.1:4040
16/10/09 09:18:05 INFO ShutdownHookManager: Deleting directory C:\Users\Administrator\AppData\Local\Temp\spark-c9f238f3-9210-4f3a-a248-11f6f610163e

Process finished with exit code 0

 

 

setLogLevel("ERROR")

The corresponding console output:

"C:\Program Files\Java\jdk1.8.0_66\bin\java" -Didea.launcher.port=7535 "-Didea.launcher.bin.path=D:\SoftWare\IntelliJ IDEA\IntelliJ IDEA Community Edition 2016.1.4\bin" -Dfile.encoding=UTF-8 -classpath "C:\Program Files\Java\jdk1.8.0_66\jre\lib\charsets.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\deploy.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\ext\access-bridge-64.jar;C:\Program 
16/10/09 09:18:43 INFO BlockManagerMasterEndpoint: Registering block manager localhost:52966 with 976.3 MB RAM, BlockManagerId(driver, localhost, 52966)
16/10/09 09:18:43 INFO BlockManagerMaster: Registered BlockManager
9
7
7
5
5
16/10/09 09:18:50 WARN QueuedThreadPool: 3 threads could not be stopped

Process finished with exit code 0




setLogLevel("FATAL")

The corresponding console output:

"C:\Program Files\Java\jdk1.8.0_66\bin\java" -Didea.launcher.port=7536 "-Didea.launcher.bin.path=D:\SoftWare\IntelliJ IDEA\IntelliJ IDEA Community Edition 2016.1.4\bin" -Dfile.encoding=UTF-8 -classpath "C:\Program Files\Java\jdk1.8.0_66\jre\lib\charsets.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\deploy.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\ext\access-bridge-64.jar;C:\Program 
16/10/09 09:20:17 INFO BlockManagerMasterEndpoint: Registering block manager localhost:53014 with 976.3 MB RAM, BlockManagerId(driver, localhost, 53014)
16/10/09 09:20:17 INFO BlockManagerMaster: Registered BlockManager
9
7
7
5
5

Process finished with exit code 0

 

 

 

setLogLevel("INFO")

The corresponding console output:

"C:\Program Files\Java\jdk1.8.0_66\bin\java" -Didea.launcher.port=7537 "-Didea.launcher.bin.path=D:\SoftWare\IntelliJ IDEA\IntelliJ IDEA Community Edition 2016.1.4\bin" -Dfile.encoding=UTF-8 -classpath "C:\Program Files\Java\jdk1.8.0_66\jre\lib\charsets.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\deploy.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\ext\access-bridge-64.jar;C:\Program 
16/10/09 09:21:17 INFO DAGScheduler: Job 0 finished: take at TopNBasic.scala:20, took 1.085930 s
9
7
7
5
5
16/10/09 09:21:17 INFO SparkContext: Invoking stop() from shutdown hook
16/10/09 09:21:17 INFO SparkUI: Stopped Spark web UI at http://192.168.56.1:4040
16/10/09 09:21:17 INFO ShutdownHookManager: Deleting directory C:\Users\Administrator\AppData\Local\Temp\spark-de03b369-fec4-4785-abec-563c502d0bd7

Process finished with exit code 0

 

 

 

setLogLevel("OFF")

The corresponding console output:

"C:\Program Files\Java\jdk1.8.0_66\bin\java" -Didea.launcher.port=7538 "-Didea.launcher.bin.path=D:\SoftWare\IntelliJ IDEA\IntelliJ IDEA Community Edition 2016.1.4\bin" -Dfile.encoding=UTF-8 -classpath "C:\Program Files\Java\jdk1.8.0_66\jre\lib\charsets.jar;C:\Program 
16/10/09 09:22:10 INFO BlockManagerMasterEndpoint: Registering block manager localhost:53098 with 976.3 MB RAM, BlockManagerId(driver, localhost, 53098)
16/10/09 09:22:10 INFO BlockManagerMaster: Registered BlockManager
9
7
7
5
5

Process finished with exit code 0

 

 

setLogLevel("TRACE")

The corresponding console output:

"C:\Program Files\Java\jdk1.8.0_66\bin\java" -Didea.launcher.port=7539 "-Didea.launcher.bin.path=D:\SoftWare\IntelliJ IDEA\IntelliJ IDEA Community Edition 2016.1.4\bin" -Dfile.encoding=UTF-8 -classpath "C:\Program Files\Java\jdk1.8.0_66\jre\lib\charsets.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\deploy.jar;C:\Program Files\Java\jdk1.8.0_66\jre\lib\ext\access-bridge-64.jar;C:\Program 
16/10/09 09:23:15 TRACE DAGScheduler: running: Set()
16/10/09 09:23:15 TRACE DAGScheduler: waiting: Set()
16/10/09 09:23:15 TRACE DAGScheduler: failed: Set()
16/10/09 09:23:15 INFO DAGScheduler: Job 0 finished: take at TopNBasic.scala:20, took 0.985096 s
9
7
7
5
5
16/10/09 09:23:15 INFO SparkContext: Invoking stop() from shutdown hook
16/10/09 09:23:15 INFO SparkUI: Stopped Spark web UI at http://192.168.56.1:4040
16/10/09 09:23:15 INFO ShutdownHookManager: Deleting directory C:\Users\Administrator\AppData\Local\Temp\spark-d3604805-b6e2-4873-a8aa-10cabda4f329

Process finished with exit code 0

 

 

 

setLogLevel("WARN")

The corresponding console output:

"C:\Program Files\Java\jdk1.8.0_66\bin\java" -Didea.launcher.port=7532 "-Didea.launcher.bin.path=D:\SoftWare\IntelliJ IDEA\IntelliJ IDEA Community Edition 2016.1.4\bin" -Dfile.encoding=UTF-8 -classpath "C:\Program Files\Java\jdk1.8.0_66\jre\lib\charsets.jar;C:\Program fe80:0:0:0:0:5efe:c0a8:bf02%net11, but we couldn't find any external IP address!
9
7
7
5
5

Process finished with exit code 0

Summary: that wraps up the basic Top N example.





2. Grouped Top N Algorithm in Practice

Let's start with the Java API.

First, the test data, groupTopN.txt:

Spark 100
Hadoop 65
Spark 99
Hadoop 61
Spark 195
Hadoop 60
Spark 98
Hadoop 69
Spark 91
Hadoop 64
Spark 89
Hadoop 98
Spark 88 
Hadoop 99
Spark 68
Hadoop 60
Spark 79
Hadoop 97
Spark 69
Hadoop 96

The Java implementation:

package com.zhouls.spark.SparkApps.cores;

import java.util.Arrays;
import java.util.Iterator;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import scala.Tuple2;

public class TopNGroup {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("TopNGroup").setMaster("local");
        JavaSparkContext sc = new JavaSparkContext(conf); // under the hood this wraps Scala's SparkContext

        JavaRDD<String> lines = sc.textFile("D://SoftWare//spark-1.5.2-bin-hadoop2.6//groupTopN.txt");

        JavaPairRDD<String, Integer> pairs = lines.mapToPair(new PairFunction<String, String, Integer>() {
            private static final long serialVersionUID = 1L;

            @Override
            public Tuple2<String, Integer> call(String line) throws Exception {
                String[] splitedLine = line.split(" ");
                System.out.println(splitedLine[0]);
                return new Tuple2<String, Integer>(splitedLine[0], Integer.valueOf(splitedLine[1]));
            }
        });

        JavaPairRDD<String, Iterable<Integer>> groupedPairs = pairs.groupByKey();

        JavaPairRDD<String, Iterable<Integer>> top5 = groupedPairs.mapToPair(
                new PairFunction<Tuple2<String, Iterable<Integer>>, String, Iterable<Integer>>() {
            private static final long serialVersionUID = 1L;

            @Override
            public Tuple2<String, Iterable<Integer>> call(Tuple2<String, Iterable<Integer>> groupedData)
                    throws Exception {
                Integer[] top5 = new Integer[5];
                String groupedKey = groupedData._1;
                Iterator<Integer> groupedValue = groupedData._2.iterator();

                while (groupedValue.hasNext()) {
                    Integer value = groupedValue.next();
                    // Insert the value into its place in the fixed-size, descending Top 5 array.
                    for (int i = 0; i < 5; i++) {
                        if (top5[i] == null) {
                            top5[i] = value;
                            break;
                        } else if (value > top5[i]) {
                            for (int j = 4; j > i; j--) {
                                top5[j] = top5[j - 1];
                            }
                            top5[i] = value;
                            break;
                        }
                    }
                }

                return new Tuple2<String, Iterable<Integer>>(groupedKey, Arrays.asList(top5));
            }
        });

        // Print the grouped Top N.
        top5.foreach(new VoidFunction<Tuple2<String, Iterable<Integer>>>() {
            @Override
            public void call(Tuple2<String, Iterable<Integer>> topped) throws Exception {
                System.out.println("Group key :" + topped._1); // the group key
                Iterator<Integer> toppedValue = topped._2.iterator(); // the group's values
                while (toppedValue.hasNext()) { // print each group's Top N
                    Integer value = toppedValue.next();
                    System.out.println(value);
                }
                System.out.println("******************************************************");
            }
        });
    }
}
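For comparison, here is a minimal Scala sketch of the same grouped Top N (my own illustrative version, not from the original post). It sorts each group's values in memory, which is fine as long as no single key carries a huge number of values:

val lines = sc.textFile("D://SoftWare//spark-1.5.2-bin-hadoop2.6//groupTopN.txt")
val top5ByKey = lines
  .map { line => val parts = line.split(" "); (parts(0), parts(1).toInt) }
  .groupByKey()
  .mapValues(values => values.toList.sortBy(v => -v).take(5))   // descending Top 5 per key
top5ByKey.collect().foreach { case (key, top5) =>
  println(s"Group key : $key")
  top5.foreach(println)
}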


 Thanks to the following blogger:

http://www.it610.com/article/5193051.htm


  If groupTopN.txt instead contains:

Spark 100
Hadoop 62
Flink 77
Kafka 91
Hadoop 93
Spark 78
Hadoop 69
Spark 98
Hadoop 62
Spark 99
Hadoop 61
Spark 70
Hadoop 75
Spark 88
Hadoop 68
Spark 90
Hadoop 61

  then, correspondingly, the program prints a Top 5 list for each of the four keys (Spark, Hadoop, Flink, Kafka); keys with fewer than five scores leave the remaining slots as null.

 Summary of the grouped Top N example:

Grouped Top N sorting:
  1. Read each line of input: JavaRDD<String> lines.
  2. Build K,V pairs: JavaPairRDD<String, Integer> pairs.
       Input: one line of text.
       Output: the key is the name and the value is the score.
  3. Group by name with groupByKey: JavaPairRDD<String, Iterable<Integer>> groupedPairs = pairs.groupByKey();
  4. Sort within each group.
       Input: groupedData, where the key is the group name and the value is the collection of scores.
       Output: the key is the group name and the value is the sorted collection of scores, keeping 5 values:

    JavaPairRDD<String, Iterable<Integer>> top5 = groupedPairs.mapToPair(new

  PairFunction<Tuple2<String,Iterable<Integer>>, String, Iterable<Integer>>() {


3. Inside the RangePartitioner Sorting Algorithm

/**
 * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
 * `collect` or `save` on the resulting RDD will return or output an ordered list of records
 * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
 * order of the keys).
 */
// TODO: this currently doesn't work on P other than Tuple2!
def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.length)
    : RDD[(K, V)] = self.withScope
{
  val part = new RangePartitioner(numPartitions, self, ascending)
  new ShuffledRDD[K, V, V](self, part)
    .setKeyOrdering(if (ascending) ordering else ordering.reverse)
}
RangePartitioner divides the data of the RDD it depends on into distinct ranges, and the crucial point is that those ranges themselves are ordered.
Besides being the foundation for globally ordered results, RangePartitioner's most important job is to keep the amount of data in each partition as even as possible!
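A minimal sketch of what this means in practice (illustrative only): after sortByKey, each partition holds a contiguous, sorted range of keys, so concatenating the partitions in order yields a globally sorted result.

val pairs = sc.parallelize(Seq(5, 1, 9, 3, 7, 2, 8, 4, 6).map(k => (k, k)), numSlices = 3)
val sorted = pairs.sortByKey(true, 3)   // builds a RangePartitioner under the hood
sorted.glom().collect().foreach(part => println(part.map(_._1).mkString(", ")))
// Each printed line (one partition) is sorted, and the partition ranges do not overlap,
// e.g. "1, 2, 3" / "4, 5, 6" / "7, 8, 9" (exact boundaries depend on the sampling).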


 A Google interview question: how do you sort data whose total size is not known in advance?

Key pieces of the sorting machinery:

1. Binary search: place each key into its corresponding partition.

   Before learning binary search, the most common approach is to scan the array and compare against every element, which costs O(n). Binary search does better, with O(log n) lookups. For example, to find the element 6 in the array {1, 2, 3, 4, 5, 6, 7, 8, 9}:

    1. Check the middle element, 5. Since 5 < 6, the target must lie to the right of 5, so continue searching in {6, 7, 8, 9}.
    2. Check the middle of {6, 7, 8, 9}, which is 7. Since 7 > 6, the target must lie to the left of 7, leaving only 6: found. A minimal sketch follows.
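Here is that idea as a minimal Scala sketch (my own illustration, not Spark's CollectionsUtils.makeBinarySearch); the -(insertion point) - 1 convention for missing keys is exactly what getPartition relies on later:

def binarySearch(sorted: Array[Int], key: Int): Int = {
  var lo = 0
  var hi = sorted.length - 1
  while (lo <= hi) {
    val mid = lo + (hi - lo) / 2
    if (sorted(mid) == key) return mid
    else if (sorted(mid) < key) lo = mid + 1   // target is to the right of mid
    else hi = mid - 1                          // target is to the left of mid
  }
  -(lo + 1)   // key not found: -(insertion point) - 1, like java.util.Arrays.binarySearch
}

println(binarySearch(Array(1, 2, 3, 4, 5, 6, 7, 8, 9), 6))   // prints 5, the index of 6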

 

2. Reservoir sampling (suited to data sets so large they cannot fit in memory). Why multiply by 3 below?
   Because the RDD's partitions may be skewed: sampleSize is the desired total sample size, but some partitions may hold fewer than sampleSize / numPartitions records. Multiplying by 3 lets the other partitions over-sample, so that the total sample still reaches or exceeds sampleSize. A quick numeric check follows the snippet.

      // This is the sample size we need to have roughly balanced output partitions, capped at 1M.
      val sampleSize = math.min(20.0 * partitions, 1e6)

      // Assume the input partitions are roughly balanced and over-sample a little bit.
      val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.size).toInt
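A quick numeric check of those two lines (the partition count is just an example):

val partitions = 10                                                           // illustrative value
val sampleSize = math.min(20.0 * partitions, 1e6)                             // 200.0
val sampleSizePerPartition = math.ceil(3.0 * sampleSize / partitions).toInt   // 60
// 10 partitions * 60 samples each = 600 candidate samples, three times the 200 actually
// needed; the headroom covers skewed partitions that yield fewer samples than expected.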

 

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark

import java.io.{IOException, ObjectInputStream, ObjectOutputStream}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.reflect.{ClassTag, classTag}
import scala.util.hashing.byteswap32

import org.apache.spark.rdd.{PartitionPruningRDD, RDD}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.util.{CollectionsUtils, Utils}
import org.apache.spark.util.random.{XORShiftRandom, SamplingUtils}

/**
 * An object that defines how the elements in a key-value pair RDD are partitioned by key.
 * Maps each key to a partition ID, from 0 to `numPartitions - 1`.
 */
abstract class Partitioner extends Serializable {
  def numPartitions: Int
  def getPartition(key: Any): Int
}

object Partitioner {
  /**
   * Choose a partitioner to use for a cogroup-like operation between a number of RDDs.
   *
   * If any of the RDDs already has a partitioner, choose that one.
   *
   * Otherwise, we use a default HashPartitioner. For the number of partitions, if
   * spark.default.parallelism is set, then we'll use the value from SparkContext
   * defaultParallelism, otherwise we'll use the max number of upstream partitions.
   *
   * Unless spark.default.parallelism is set, the number of partitions will be the
   * same as the number of partitions in the largest upstream RDD, as this should
   * be least likely to cause out-of-memory errors.
   *
   * We use two method parameters (rdd, others) to enforce callers passing at least 1 RDD.
   */
  def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = {
    val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.size).reverse
    for (r <- bySize if r.partitioner.isDefined && r.partitioner.get.numPartitions > 0) {
      return r.partitioner.get
    }
    if (rdd.context.conf.contains("spark.default.parallelism")) {
      new HashPartitioner(rdd.context.defaultParallelism)
    } else {
      new HashPartitioner(bySize.head.partitions.size)
    }
  }
}

/**
 * A [[org.apache.spark.Partitioner]] that implements hash-based partitioning using
 * Java's `Object.hashCode`.
 *
 * Java arrays have hashCodes that are based on the arrays' identities rather than their contents,
 * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will
 * produce an unexpected or incorrect result.
 */
class HashPartitioner(partitions: Int) extends Partitioner {
  require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")

  def numPartitions: Int = partitions

  def getPartition(key: Any): Int = key match {
    case null => 0
    case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
  }

  override def equals(other: Any): Boolean = other match {
    case h: HashPartitioner =>
      h.numPartitions == numPartitions
    case _ =>
      false
  }

  override def hashCode: Int = numPartitions
}

/**
 * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly
 * equal ranges. The ranges are determined by sampling the content of the RDD passed in.
 *
 * Note that the actual number of partitions created by the RangePartitioner might not be the same
 * as the `partitions` parameter, in the case where the number of sampled records is less than
 * the value of `partitions`.
 */
class RangePartitioner[K : Ordering : ClassTag, V](
    @transient partitions: Int,
    @transient rdd: RDD[_ <: Product2[K, V]],
    private var ascending: Boolean = true)
  extends Partitioner {

  // We allow partitions = 0, which happens when sorting an empty RDD under the default settings.
  require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.")

  private var ordering = implicitly[Ordering[K]]

  // An array of upper bounds for the first (partitions - 1) partitions
  private var rangeBounds: Array[K] = {
    if (partitions <= 1) {
      Array.empty
    } else {
      // This is the sample size we need to have roughly balanced output partitions, capped at 1M.
      val sampleSize = math.min(20.0 * partitions, 1e6)
      // Assume the input partitions are roughly balanced and over-sample a little bit.
      val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.size).toInt
      val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition)
      if (numItems == 0L) {
        Array.empty
      } else {
        // If a partition contains much more than the average number of items, we re-sample from it
        // to ensure that enough items are collected from that partition.
        val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
        val candidates = ArrayBuffer.empty[(K, Float)]
        val imbalancedPartitions = mutable.Set.empty[Int]
        sketched.foreach { case (idx, n, sample) =>
          if (fraction * n > sampleSizePerPartition) {
            imbalancedPartitions += idx
          } else {
            // The weight is 1 over the sampling probability.
            val weight = (n.toDouble / sample.size).toFloat
            for (key <- sample) {
              candidates += ((key, weight))
            }
          }
        }
        if (imbalancedPartitions.nonEmpty) {
          // Re-sample imbalanced partitions with the desired sampling probability.
          val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains)
          val seed = byteswap32(-rdd.id - 1)
          val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect()
          val weight = (1.0 / fraction).toFloat
          candidates ++= reSampled.map(x => (x, weight))
        }
        RangePartitioner.determineBounds(candidates, partitions)
      }
    }
  }

  def numPartitions: Int = rangeBounds.length + 1

  private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]

  def getPartition(key: Any): Int = {
    val k = key.asInstanceOf[K]
    var partition = 0
    if (rangeBounds.length <= 128) {
      // If we have less than 128 partitions naive search
      while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) {
        partition += 1
      }
    } else {
      // Determine which binary search method to use only once.
      partition = binarySearch(rangeBounds, k)
      // binarySearch either returns the match location or -[insertion point]-1
      if (partition < 0) {
        partition = -partition-1
      }
      if (partition > rangeBounds.length) {
        partition = rangeBounds.length
      }
    }
    if (ascending) {
      partition
    } else {
      rangeBounds.length - partition
    }
  }

  override def equals(other: Any): Boolean = other match {
    case r: RangePartitioner[_, _] =>
      r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending
    case _ =>
      false
  }

  override def hashCode(): Int = {
    val prime = 31
    var result = 1
    var i = 0
    while (i < rangeBounds.length) {
      result = prime * result + rangeBounds(i).hashCode
      i += 1
    }
    result = prime * result + ascending.hashCode
    result
  }

  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
    val sfactory = SparkEnv.get.serializer
    sfactory match {
      case js: JavaSerializer => out.defaultWriteObject()
      case _ =>
        out.writeBoolean(ascending)
        out.writeObject(ordering)
        out.writeObject(binarySearch)

        val ser = sfactory.newInstance()
        Utils.serializeViaNestedStream(out, ser) { stream =>
          stream.writeObject(scala.reflect.classTag[Array[K]])
          stream.writeObject(rangeBounds)
        }
    }
  }

  @throws(classOf[IOException])
  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
    val sfactory = SparkEnv.get.serializer
    sfactory match {
      case js: JavaSerializer => in.defaultReadObject()
      case _ =>
        ascending = in.readBoolean()
        ordering = in.readObject().asInstanceOf[Ordering[K]]
        binarySearch = in.readObject().asInstanceOf[(Array[K], K) => Int]

        val ser = sfactory.newInstance()
        Utils.deserializeViaNestedStream(in, ser) { ds =>
          implicit val classTag = ds.readObject[ClassTag[Array[K]]]()
          rangeBounds = ds.readObject[Array[K]]()
        }
    }
  }
}

private[spark] object RangePartitioner {

  /**
   * Sketches the input RDD via reservoir sampling on each partition.
   *
   * @param rdd the input RDD to sketch
   * @param sampleSizePerPartition max sample size per partition
   * @return (total number of items, an array of (partitionId, number of items, sample))
   */
  def sketch[K : ClassTag](
      rdd: RDD[K],
      sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = {
    val shift = rdd.id
    // val classTagK = classTag[K] // to avoid serializing the entire partitioner object
    val sketched = rdd.mapPartitionsWithIndex { (idx, iter) =>
      val seed = byteswap32(idx ^ (shift << 16))
      val (sample, n) = SamplingUtils.reservoirSampleAndCount(
        iter, sampleSizePerPartition, seed)
      Iterator((idx, n, sample))
    }.collect()
    val numItems = sketched.map(_._2.toLong).sum
    (numItems, sketched)
  }

  /**
   * Determines the bounds for range partitioning from candidates with weights indicating how many
   * items each represents. Usually this is 1 over the probability used to sample this candidate.
   *
   * @param candidates unordered candidates with weights
   * @param partitions number of partitions
   * @return selected bounds
   */
  def determineBounds[K : Ordering : ClassTag](
      candidates: ArrayBuffer[(K, Float)],
      partitions: Int): Array[K] = {
    val ordering = implicitly[Ordering[K]]
    val ordered = candidates.sortBy(_._1)
    val numCandidates = ordered.size
    val sumWeights = ordered.map(_._2.toDouble).sum
    val step = sumWeights / partitions
    var cumWeight = 0.0
    var target = step
    val bounds = ArrayBuffer.empty[K]
    var i = 0
    var j = 0
    var previousBound = Option.empty[K]
    while ((i < numCandidates) && (j < partitions - 1)) {
      val (key, weight) = ordered(i)
      cumWeight += weight
      if (cumWeight > target) {
        // Skip duplicate values.
        if (previousBound.isEmpty || ordering.gt(key, previousBound.get)) {
          bounds += key
          target += step
          j += 1
          previousBound = Some(key)
        }
      }
      i += 1
    }
    bounds.toArray
  }
}
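One routine in the source worth pausing on is determineBounds: it walks the weight-sorted candidates and emits a bound each time the cumulative weight crosses a step of sumWeights / partitions. A tiny illustrative trace with made-up candidates, all of weight 1.0 (a simplified restatement that ignores the duplicate-key handling of the real code):

val candidates = Seq(3, 9, 1, 7, 5, 8, 2, 6, 4).map(k => (k, 1.0f))   // (key, weight)
val partitions = 3
val ordered = candidates.sortBy(_._1)
val step = ordered.map(_._2.toDouble).sum / partitions                // 9.0 / 3 = 3.0
var cumWeight = 0.0
var target = step
val bounds = scala.collection.mutable.ArrayBuffer.empty[Int]
for ((key, weight) <- ordered if bounds.length < partitions - 1) {
  cumWeight += weight
  if (cumWeight > target) { bounds += key; target += step }
}
println(bounds.mkString(", "))   // 4, 7  => key ranges (-inf, 4], (4, 7], (7, +inf)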

 




 As highlighted in the source above:

The reservoir-sampling step (suited to data too large for memory) multiplies by 3 because the RDD's partitions may be skewed: sampleSize is the desired total sample size, but some partitions may yield fewer than sampleSize / numPartitions samples, and the 3x over-sampling on the other partitions keeps the overall sample at or above sampleSize.

      // This is the sample size we need to have roughly balanced output partitions, capped at 1M.
      val sampleSize = math.min(20.0 * partitions, 1e6)

      // Assume the input partitions are roughly balanced and over-sample a little bit.
      val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.size).toInt


The sketch source:

/**
 * Sketches the input RDD via reservoir sampling on each partition.
 *
 * @param rdd the input RDD to sketch
 * @param sampleSizePerPartition max sample size per partition
 * @return (total number of items, an array of (partitionId, number of items, sample))
 */
def sketch[K : ClassTag](
    rdd: RDD[K],
    sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = {
  val shift = rdd.id
  // val classTagK = classTag[K] // to avoid serializing the entire partitioner object
  val sketched = rdd.mapPartitionsWithIndex { (idx, iter) =>
    val seed = byteswap32(idx ^ (shift << 16))
    val (sample, n) = SamplingUtils.reservoirSampleAndCount(
      iter, sampleSizePerPartition, seed)
    Iterator((idx, n, sample))
  }.collect()
  val numItems = sketched.map(_._2.toLong).sum
  (numItems, sketched)
}

 




 

   The reservoirSampleAndCount source:

 

/**
 * Reservoir sampling implementation that also returns the input size.
 *
 * @param input input size
 * @param k reservoir size
 * @param seed random seed
 * @return (samples, input size)
 */
def reservoirSampleAndCount[T: ClassTag](
    input: Iterator[T],
    k: Int,
    seed: Long = Random.nextLong())
  : (Array[T], Int) = {
  val reservoir = new Array[T](k)
  // Put the first k elements in the reservoir.
  var i = 0
  while (i < k && input.hasNext) {
    val item = input.next()
    reservoir(i) = item
    i += 1
  }

  // If we have consumed all the elements, return them. Otherwise do the replacement.
  if (i < k) {
    // If input size < k, trim the array to return only an array of input size.
    val trimReservoir = new Array[T](i)
    System.arraycopy(reservoir, 0, trimReservoir, 0, i)
    (trimReservoir, i)
  } else {
    // If input size > k, continue the sampling process.
    val rand = new XORShiftRandom(seed)
    while (input.hasNext) {
      val item = input.next()
      val replacementIndex = rand.nextInt(i)
      if (replacementIndex < k) {
        reservoir(replacementIndex) = item
      }
      i += 1
    }
    (reservoir, i)
  }
}
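To see what reservoirSampleAndCount does without running Spark, here is a standalone sketch of the same reservoir-sampling idea (my own illustrative code, not Spark's SamplingUtils):

import scala.reflect.ClassTag
import scala.util.Random

// Keep k elements from a stream of unknown length, each with equal probability.
def reservoirSample[T: ClassTag](input: Iterator[T], k: Int, seed: Long): (Array[T], Int) = {
  val rand = new Random(seed)
  val reservoir = new scala.collection.mutable.ArrayBuffer[T](k)
  var i = 0
  input.foreach { item =>
    if (i < k) {
      reservoir += item                  // fill the reservoir with the first k items
    } else {
      val j = rand.nextInt(i + 1)        // item replaces a slot with probability k / (i + 1)
      if (j < k) reservoir(j) = item
    }
    i += 1
  }
  (reservoir.toArray, i)                 // (samples, total number of items seen)
}

val (sample, n) = reservoirSample((1 to 1000).iterator, 10, seed = 42L)
println(s"saw $n items, sampled: ${sample.mkString(", ")}")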

 




 

 

 
 
The getPartition source:

def getPartition(key: Any): Int = {
  val k = key.asInstanceOf[K]
  var partition = 0
  if (rangeBounds.length <= 128) {
    // If we have less than 128 partitions naive search
    while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) {
      partition += 1
    }
  } else {
    // Determine which binary search method to use only once.
    partition = binarySearch(rangeBounds, k)
    // binarySearch either returns the match location or -[insertion point]-1
    if (partition < 0) {
      partition = -partition-1
    }
    if (partition > rangeBounds.length) {
      partition = rangeBounds.length
    }
  }
  if (ascending) {
    partition
  } else {
    rangeBounds.length - partition
  }
}
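A worked example of the logic above (the bounds are made up): with rangeBounds = Array(10, 20, 30) there are four partitions, and an ascending key goes to the first partition whose upper bound it does not exceed.

// Illustrative stand-in for the ascending branch of RangePartitioner.getPartition.
val rangeBounds = Array(10, 20, 30)                    // upper bounds of partitions 0, 1, 2
def partitionFor(key: Int): Int = {
  var partition = 0
  while (partition < rangeBounds.length && key > rangeBounds(partition)) partition += 1
  partition
}
Seq(3, 10, 15, 25, 99).foreach(k => println(s"key $k -> partition ${partitionFor(k)}"))
// key 3 -> partition 0, key 10 -> partition 0, key 15 -> partition 1,
// key 25 -> partition 2, key 99 -> partition 3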


    Binary search places each key into its corresponding partition.

   As noted earlier, a linear scan compares against every element and costs O(n), while binary search costs O(log n). For the array {1, 2, 3, 4, 5, 6, 7, 8, 9} and target 6:
    1. The middle element is 5; since 5 < 6, continue searching in {6, 7, 8, 9}.
    2. The middle of {6, 7, 8, 9} is 7; since 7 > 6, only 6 remains to its left, and that is the match.




Binary search determines which partition a given key belongs to; once that is in place, RangePartitioner can do its job.
For more, see http://www.it610.com/article/5193051.htm