HashPartitioner
聚合算子默认分区器 通过hash值分区
RangePartitioner
范围分区器
排序类算子默认分区器
使用水塘抽样算法(抽样概率相同),对数据进行抽样来划分数据
边界数组(rangeBounds):存放前 (分区数 - 1) 个分区的上界,因此数组长度通常为 分区数 - 1(抽样得到的记录数少于分区数时会更短);边界值由水塘抽样得到的样本计算得出,用来确定数据切分的范围
源码:
/**
 * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly
 * equal ranges. The ranges are determined by sampling the content of the RDD passed in.
 *
 * @note The actual number of partitions created by the RangePartitioner might not be the same
 * as the `partitions` parameter, in the case where the number of sampled records is less than
 * the value of `partitions`.
 *
 * @param partitions requested number of partitions; must not be negative
 * @param rdd        the RDD of key/value records whose keys are sampled to compute the bounds
 * @param ascending  when false, the partition index computed from the bounds is mirrored
 *                   (`rangeBounds.length - partition`)
 */
class RangePartitioner[K : Ordering : ClassTag, V](
    partitions: Int,
    rdd: RDD[_ <: Product2[K, V]],
    private var ascending: Boolean = true)
  extends Partitioner {

  // We allow partitions = 0, which happens when sorting an empty RDD under the default settings.
  require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.")

  // Declared as a var because readObject re-assigns it during custom deserialization.
  private var ordering = implicitly[Ordering[K]]

  // An array of upper bounds for the first (partitions - 1) partitions
  private var rangeBounds: Array[K] = {
    if (partitions <= 1) {
      // Zero or one partition needs no boundary at all.
      Array.empty
    } else {
      // This is the sample size we need to have roughly balanced output partitions, capped at 1M.
      val sampleSize = math.min(20.0 * partitions, 1e6)
      // Assume the input partitions are roughly balanced and over-sample a little bit.
      val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.length).toInt
      val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition)
      if (numItems == 0L) {
        Array.empty
      } else {
        // If a partition contains much more than the average number of items, we re-sample from it
        // to ensure that enough items are collected from that partition.
        val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
        val candidates = ArrayBuffer.empty[(K, Float)]
        val imbalancedPartitions = mutable.Set.empty[Int]
        sketched.foreach { case (idx, n, sample) =>
          if (fraction * n > sampleSizePerPartition) {
            imbalancedPartitions += idx
          } else {
            // The weight is 1 over the sampling probability.
            val weight = (n.toDouble / sample.length).toFloat
            for (key <- sample) {
              candidates += ((key, weight))
            }
          }
        }
        if (imbalancedPartitions.nonEmpty) {
          // Re-sample imbalanced partitions with the desired sampling probability.
          val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains)
          val seed = byteswap32(-rdd.id - 1)
          val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect()
          val weight = (1.0 / fraction).toFloat
          candidates ++= reSampled.map(x => (x, weight))
        }
        RangePartitioner.determineBounds(candidates, partitions)
      }
    }
  }

  /** Number of partitions actually produced: one more than the number of range bounds. */
  def numPartitions: Int = rangeBounds.length + 1

  // Declared as a var because readObject re-assigns it during custom deserialization.
  private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]

  /**
   * Returns the partition index for `key`, which is cast to `K`.
   *
   * The index is the position of the first range bound that `key` is not greater than
   * (or `rangeBounds.length` if `key` is greater than every bound); for a descending
   * partitioner that index is mirrored.
   */
  def getPartition(key: Any): Int = {
    val k = key.asInstanceOf[K]
    var partition = 0
    if (rangeBounds.length <= 128) {
      // With at most 128 range bounds, use a naive linear scan.
      while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) {
        partition += 1
      }
    } else {
      // Determine which binary search method to use only once.
      partition = binarySearch(rangeBounds, k)
      // binarySearch either returns the match location or -[insertion point]-1
      if (partition < 0) {
        partition = -partition-1
      }
      if (partition > rangeBounds.length) {
        partition = rangeBounds.length
      }
    }
    if (ascending) {
      partition
    } else {
      rangeBounds.length - partition
    }
  }

  /** Two RangePartitioners are equal when their range bounds and sort direction match. */
  override def equals(other: Any): Boolean = other match {
    case r: RangePartitioner[_, _] =>
      r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending
    case _ =>
      false
  }

  /** Hash code built from the same fields `equals` compares: every range bound and `ascending`. */
  override def hashCode(): Int = {
    val prime = 31
    var result = 1
    var i = 0
    while (i < rangeBounds.length) {
      result = prime * result + rangeBounds(i).hashCode
      i += 1
    }
    result = prime * result + ascending.hashCode
    result
  }

  /**
   * Custom serialization hook. With a JavaSerializer the default object serialization is used;
   * otherwise `ascending`, `ordering` and `binarySearch` are written to `out` directly, and the
   * class tag plus `rangeBounds` are written through a nested stream of the configured serializer.
   */
  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
    val sfactory = SparkEnv.get.serializer
    sfactory match {
      case js: JavaSerializer => out.defaultWriteObject()
      case _ =>
        out.writeBoolean(ascending)
        out.writeObject(ordering)
        out.writeObject(binarySearch)

        val ser = sfactory.newInstance()
        Utils.serializeViaNestedStream(out, ser) { stream =>
          stream.writeObject(scala.reflect.classTag[Array[K]])
          stream.writeObject(rangeBounds)
        }
    }
  }

  /**
   * Custom deserialization hook; reads the fields back in exactly the order `writeObject`
   * wrote them.
   */
  @throws(classOf[IOException])
  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
    val sfactory = SparkEnv.get.serializer
    sfactory match {
      case js: JavaSerializer => in.defaultReadObject()
      case _ =>
        ascending = in.readBoolean()
        ordering = in.readObject().asInstanceOf[Ordering[K]]
        binarySearch = in.readObject().asInstanceOf[(Array[K], K) => Int]

        val ser = sfactory.newInstance()
        Utils.deserializeViaNestedStream(in, ser) { ds =>
          // The class tag is read first and made implicit for the following array read.
          implicit val classTag = ds.readObject[ClassTag[Array[K]]]()
          rangeBounds = ds.readObject[Array[K]]()
        }
    }
  }
}
自定义分区器
实现Partitioner类
/**
 * Custom hash partitioner: assigns each key to a partition based on its `hashCode`.
 *
 * @param num number of partitions; must be positive
 */
class MyPartitioner(num: Int) extends Partitioner {
  assert(num > 0)

  override def numPartitions: Int = num

  /**
   * Maps `key` to a partition index in `[0, num)`. A `null` key always goes to partition 0.
   */
  override def getPartition(key: Any): Int = key match {
    case null => 0
    // Take the absolute value as a Long: `Int.MinValue.abs` overflows and stays negative,
    // which would make `% num` produce a negative (invalid) partition index.
    case _ => (key.hashCode.toLong.abs % num).toInt
  }
}