Spark Partitioners: HashPartitioner, RangePartitioner, and Custom Partitioners

HashPartitioner

The default partitioner for aggregation operators such as reduceByKey and groupByKey; it assigns each record to a partition based on the hash value of its key.
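The core logic is small enough to sketch. Below is a minimal re-implementation of what HashPartitioner does (the class name SimpleHashPartitioner is made up for illustration; Spark's real class in org.apache.spark behaves the same way):

import org.apache.spark.Partitioner

class SimpleHashPartitioner(partitions: Int) extends Partitioner {
  require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.")

  override def numPartitions: Int = partitions

  override def getPartition(key: Any): Int = key match {
    // null keys always land in partition 0
    case null => 0
    case _ =>
      // key.hashCode may be negative, so force the result into [0, numPartitions)
      val rawMod = key.hashCode % numPartitions
      rawMod + (if (rawMod < 0) numPartitions else 0)
  }

  // Two hash partitioners are interchangeable iff they have the same partition count.
  override def equals(other: Any): Boolean = other match {
    case h: SimpleHashPartitioner => h.numPartitions == numPartitions
    case _ => false
  }

  override def hashCode: Int = numPartitions
}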

RangePartitioner

A range partitioner: it splits the key space into contiguous, roughly equal ranges.

It is the default partitioner for sort-based operators such as sortByKey.

It uses reservoir sampling (every record has the same probability of being selected) to sample the data and decide where to split it.
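A minimal sketch of reservoir sampling, the technique RangePartitioner relies on (Spark's own version is SamplingUtils.reservoirSampleAndCount; the function below is a simplified stand-in): keep the first k items, then let each later item replace a random slot with shrinking probability, so every item ends up in the sample with probability k / n.

import scala.reflect.ClassTag
import scala.util.Random

// Sample k items in one pass over an iterator of unknown length,
// also returning the total number of items seen.
def reservoirSampleAndCount[T: ClassTag](
    input: Iterator[T], k: Int, seed: Long = 42L): (Array[T], Long) = {
  val reservoir = new Array[T](k)
  // Phase 1: fill the reservoir with the first k items.
  var i = 0
  while (i < k && input.hasNext) {
    reservoir(i) = input.next()
    i += 1
  }
  if (i < k) {
    // Fewer than k items in total: the trimmed reservoir is the whole input.
    (reservoir.take(i), i.toLong)
  } else {
    // Phase 2: the n-th item replaces a random slot with probability k / n,
    // which keeps every item's overall selection probability at k / n.
    val rand = new Random(seed)
    var n = i.toLong
    while (input.hasNext) {
      val item = input.next()
      n += 1
      val replacementIndex = (rand.nextDouble() * n).toLong
      if (replacementIndex < k) {
        reservoir(replacementIndex.toInt) = item
      }
    }
    (reservoir, n)
  }
}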

Bounds array: its length is determined by the number of partitions (at most numPartitions - 1 upper bounds). The split points computed from the reservoir samples are stored in this array, and getPartition later locates each key's range within it.
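Before reading the source, a hypothetical end-to-end sketch of applying the partitioner directly (local mode; the object name and sample data are made up for illustration):

import org.apache.spark.{RangePartitioner, SparkConf, SparkContext}

object RangePartitionerDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("range-demo").setMaster("local[2]"))
    val pairs = sc.parallelize(Seq(("d", 1), ("a", 2), ("c", 3), ("b", 4)))
    // sortByKey builds a RangePartitioner internally; here we apply one explicitly.
    val partitioned = pairs.partitionBy(new RangePartitioner(2, pairs))
    // Each output partition now holds a contiguous key range, e.g. {"a","b"}
    // and {"c","d"}; the exact split point depends on the sampled bounds.
    partitioned.glom().collect().foreach(part => println(part.mkString(", ")))
    sc.stop()
  }
}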

Source code:

/**
* A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly
* equal ranges. The ranges are determined by sampling the content of the RDD passed in.
*
* @note The actual number of partitions created by the RangePartitioner might not be the same
* as the `partitions` parameter, in the case where the number of sampled records is less than
* the value of `partitions`.
*/
class RangePartitioner[K : Ordering : ClassTag, V](
   partitions: Int,
   rdd: RDD[_ <: Product2[K, V]],
   private var ascending: Boolean = true)
 extends Partitioner {

 // We allow partitions = 0, which happens when sorting an empty RDD under the default settings.
 require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.")

 private var ordering = implicitly[Ordering[K]]

 // An array of upper bounds for the first (partitions - 1) partitions
 private var rangeBounds: Array[K] = {
   if (partitions <= 1) {
     Array.empty
   } else {
     // This is the sample size we need to have roughly balanced output partitions, capped at 1M.
     val sampleSize = math.min(20.0 * partitions, 1e6)
     // Assume the input partitions are roughly balanced and over-sample a little bit.
     val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.length).toInt
     val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition)
     if (numItems == 0L) {
       Array.empty
     } else {
       // If a partition contains much more than the average number of items, we re-sample from it
       // to ensure that enough items are collected from that partition.
       val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
       val candidates = ArrayBuffer.empty[(K, Float)]
       val imbalancedPartitions = mutable.Set.empty[Int]
       sketched.foreach { case (idx, n, sample) =>
         if (fraction * n > sampleSizePerPartition) {
           imbalancedPartitions += idx
         } else {
           // The weight is 1 over the sampling probability.
           val weight = (n.toDouble / sample.length).toFloat
           for (key <- sample) {
             candidates += ((key, weight))
           }
         }
       }
       if (imbalancedPartitions.nonEmpty) {
         // Re-sample imbalanced partitions with the desired sampling probability.
         val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains)
         val seed = byteswap32(-rdd.id - 1)
         val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect()
         val weight = (1.0 / fraction).toFloat
         candidates ++= reSampled.map(x => (x, weight))
       }
       RangePartitioner.determineBounds(candidates, partitions)
     }
   }
 }

 def numPartitions: Int = rangeBounds.length + 1

 private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]

 def getPartition(key: Any): Int = {
   val k = key.asInstanceOf[K]
   var partition = 0
   if (rangeBounds.length <= 128) {
     // If we have less than 128 partitions naive search
     while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) {
       partition += 1
     }
   } else {
     // Determine which binary search method to use only once.
     partition = binarySearch(rangeBounds, k)
     // binarySearch either returns the match location or -[insertion point]-1
     if (partition < 0) {
       partition = -partition-1
     }
     if (partition > rangeBounds.length) {
       partition = rangeBounds.length
     }
   }
   if (ascending) {
     partition
   } else {
     rangeBounds.length - partition
   }
 }

 override def equals(other: Any): Boolean = other match {
   case r: RangePartitioner[_, _] =>
     r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending
   case _ =>
     false
 }

 override def hashCode(): Int = {
   val prime = 31
   var result = 1
   var i = 0
   while (i < rangeBounds.length) {
     result = prime * result + rangeBounds(i).hashCode
     i += 1
   }
   result = prime * result + ascending.hashCode
   result
 }

 @throws(classOf[IOException])
 private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
   val sfactory = SparkEnv.get.serializer
   sfactory match {
     case js: JavaSerializer => out.defaultWriteObject()
     case _ =>
       out.writeBoolean(ascending)
       out.writeObject(ordering)
       out.writeObject(binarySearch)

       val ser = sfactory.newInstance()
       Utils.serializeViaNestedStream(out, ser) { stream =>
         stream.writeObject(scala.reflect.classTag[Array[K]])
         stream.writeObject(rangeBounds)
       }
   }
 }

 @throws(classOf[IOException])
 private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
   val sfactory = SparkEnv.get.serializer
   sfactory match {
     case js: JavaSerializer => in.defaultReadObject()
     case _ =>
       ascending = in.readBoolean()
       ordering = in.readObject().asInstanceOf[Ordering[K]]
       binarySearch = in.readObject().asInstanceOf[(Array[K], K) => Int]

       val ser = sfactory.newInstance()
       Utils.deserializeViaNestedStream(in, ser) { ds =>
         implicit val classTag = ds.readObject[ClassTag[Array[K]]]()
         rangeBounds = ds.readObject[Array[K]]()
       }
   }
 }
}
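As a concrete example of getPartition: with rangeBounds = Array(10, 20) (so numPartitions = 3) and ascending = true, key 5 maps to partition 0, key 15 to partition 1, and key 25 to partition 2; with ascending = false the mapping is mirrored, so 25 maps to partition 0 and 5 to partition 2.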

Custom Partitioners

Extend the Partitioner abstract class and override numPartitions and getPartition:


import org.apache.spark.Partitioner

// A custom hash-based partitioner
class MyPartitioner(num: Int) extends Partitioner {

  require(num > 0, s"Number of partitions must be positive but found $num.")

  override def numPartitions: Int = num

  override def getPartition(key: Any): Int = key match {
    case null => 0
    case _ =>
      // hashCode().abs is unsafe here: Int.MinValue.abs is still negative,
      // so force a non-negative modulo instead.
      val rawMod = key.hashCode % num
      rawMod + (if (rawMod < 0) num else 0)
  }
}
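A hypothetical usage sketch (sc is an existing SparkContext):

val rdd = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3), (null, 4)))
val partitioned = rdd.partitionBy(new MyPartitioner(2))
println(partitioned.getNumPartitions) // 2
// The null key lands in partition 0; the rest are spread by key hash.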