Spark: sortByKey Source Code Analysis

Spark version: 2.4.0

Source location: org.apache.spark.rdd.OrderedRDDFunctions

Usage example:

import scala.reflect.ClassTag

import org.apache.spark.{Partition, SparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

object SortByKeyDemo {
  // Print how the elements of an RDD are distributed across its partitions
  def printPartAndElement[T: ClassTag](rdd: RDD[T]): Unit = {
    val parts: Array[Partition] = rdd.partitions
    for (p <- parts) {
      val partIndex = p.index
      // Keep only the elements of the partition currently being printed
      val partRdd: RDD[T] = rdd.mapPartitionsWithIndex {
        case (index: Int, value: Iterator[T]) =>
          if (index == partIndex)
            value
          else
            Iterator()
      }
      println("Partition id: " + partIndex)
      partRdd.foreach(println)
    }
  }

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession
      .builder()
      .appName("SortByKeyDemo")
      .config("spark.master", "local[*]")
      .config("spark.driver.host", "localhost")
      .getOrCreate()
    val sc: SparkContext = spark.sparkContext
    sc.setLogLevel("ERROR")

    val sourceRdd1: RDD[(String, Int)] = sc.parallelize(Seq(("a", 2), ("c", 27), ("f", 3), ("f", 0), ("f", 9),
      ("d", 3), ("d", 2), ("c", 9), ("a", 1), ("b", 2), ("b", 3), ("b", 3), ("c", 2)))
    val sortRdd1: RDD[(String, Int)] = sourceRdd1.sortByKey(false, 2) // 2 partitions, descending
    println("Element distribution per partition for sortRdd1:")
    printPartAndElement(sortRdd1)
    val sourceRdd2: RDD[((String, Int), Int)] = sc.parallelize(Seq((("a", 2), 9), (("c", 27), 3), (("f", 3), 4),
      (("f", 0), 0), (("f", 9), 1), (("d", 3), 7),
      (("d", 2), 2), (("c", 9), 12), (("a", 1), 4), (("b", 2), 34), (("b", 3), 1), (("b", 3), 6), (("c", 2), 8)))
    val sortRdd2: RDD[((String, Int), Int)] = sourceRdd2.sortByKey(false, 3) // 3 partitions, descending
    println("-------------------------------------------------")
    println("Element distribution per partition for sortRdd2:")
    printPartAndElement(sortRdd2)

    spark.stop()
  }
}

Printed results:
The output shows that the sortByKey operator sorts the data by key:
sourceRdd1: RDD[(String, Int)] is sorted by its String key
sourceRdd2: RDD[((String, Int), Int)] is sorted first by the String inside its (String, Int) key, then by the Int

Element distribution per partition for sortRdd1:
Partition id: 0
(f,3)
(f,0)
(f,9)
(d,3)
(d,2)
Partition id: 1
(c,27)
(c,9)
(c,2)
(b,2)
(b,3)
(b,3)
(a,2)
(a,1)
-------------------------------------------------
Element distribution per partition for sortRdd2:
Partition id: 0
((f,9),1)
((f,3),4)
((f,0),0)
((d,3),7)
Partition id: 1
((d,2),2)
((c,27),3)
((c,9),12)
((c,2),8)
Partition id: 2
((b,3),1)
((b,3),6)
((b,2),34)
((a,2),9)
((a,1),4)
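
The ordering seen in sortRdd2 comes from Scala's implicit Ordering for Tuple2, which sortByKey picks up through the rddToOrderedRDDFunctions implicit conversion. A minimal sketch of that comparison behaviour (plain Scala, no Spark required; the keys are taken from the example above):

object TupleOrderingSketch {
  def main(args: Array[String]): Unit = {
    // Ordering.Tuple2 compares the String first, then the Int
    val keys = Seq(("a", 2), ("c", 27), ("f", 3), ("c", 9), ("b", 3))
    // .reverse gives the descending order used by sortByKey(false, ...)
    val sorted = keys.sorted(Ordering[(String, Int)].reverse)
    println(sorted) // List((f,3), (c,27), (c,9), (b,3), (a,2))
  }
}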
The sortByKey source code is as follows:
/**
   * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
   * `collect` or `save` on the resulting RDD will return or output an ordered list of records
   * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
   * order of the keys).
   */
  // TODO: this currently doesn't work on P other than Tuple2! (this operator only works on Tuple2, i.e. key-value, data)
  // ascending: Boolean = true selects ascending vs. descending order
  // the data is distributed by key into numPartitions range segments
  def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.length)
      : RDD[(K, V)] = self.withScope
  {
    // sample the data; the sample determines the boundary of each range
    val part = new RangePartitioner(numPartitions, self, ascending)
    new ShuffledRDD[K, V, V](self, part)
      .setKeyOrdering(if (ascending) ordering else ordering.reverse)
  }
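
For intuition, the two steps above (build a RangePartitioner, then shuffle while sorting by key) can be approximated with the public repartitionAndSortWithinPartitions API. The following is a rough sketch for the ascending case only, not the actual implementation, which constructs a ShuffledRDD directly and can also reverse the ordering:

import org.apache.spark.RangePartitioner
import org.apache.spark.rdd.RDD

// Hypothetical helper, not part of Spark: roughly what sortByKey(ascending = true)
// does for an RDD[(String, Int)] -- range-partition by key, then sort each
// partition by key during the shuffle
def sortByKeyLike(rdd: RDD[(String, Int)], numPartitions: Int): RDD[(String, Int)] = {
  val part = new RangePartitioner(numPartitions, rdd, ascending = true)
  rdd.repartitionAndSortWithinPartitions(part)
}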
Below is the method RangePartitioner uses to assign records to partitions.

Based on the key's position relative to rangeBounds, it determines which range the key falls into, and hence which partition it belongs to:

def getPartition(key: Any): Int = {
    val k = key.asInstanceOf[K]
    var partition = 0
    if (rangeBounds.length <= 128) {
      // If we have less than 128 partitions naive search
      while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) {
        partition += 1
      }
    } else {
      // Determine which binary search method to use only once.
      partition = binarySearch(rangeBounds, k)
      // binarySearch either returns the match location or -[insertion point]-1
      if (partition < 0) {
        partition = -partition-1
      }
      if (partition > rangeBounds.length) {
        partition = rangeBounds.length
      }
    }
    if (ascending) {
      partition
    } else {
      rangeBounds.length - partition
    }
  }
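
To make the lookup concrete, here is a standalone sketch of the linear-scan branch, assuming a hypothetical rangeBounds of Array("b", "d") for 3 partitions (plain Scala, no Spark required):

object GetPartitionSketch {
  def main(args: Array[String]): Unit = {
    val rangeBounds = Array("b", "d") // upper bounds of the first two partitions
    val ordering = Ordering[String]

    def getPartition(key: String, ascending: Boolean): Int = {
      var partition = 0
      // advance past every bound that is smaller than the key
      while (partition < rangeBounds.length && ordering.gt(key, rangeBounds(partition))) {
        partition += 1
      }
      if (ascending) partition else rangeBounds.length - partition
    }

    for (k <- Seq("a", "b", "c", "e")) {
      println(s"$k -> ascending: ${getPartition(k, ascending = true)}, descending: ${getPartition(k, ascending = false)}")
    }
    // a -> ascending: 0, descending: 2
    // b -> ascending: 0, descending: 2
    // c -> ascending: 1, descending: 1
    // e -> ascending: 2, descending: 0
  }
}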

The sample size is derived from the partitioner's parameters, and the collected samples are then used to determine the boundaries of the range segments:

// An array of upper bounds for the first (partitions - 1) partitions
  private var rangeBounds: Array[K] = {
    if (partitions <= 1) {
      Array.empty
    } else {
      // This is the sample size we need to have roughly balanced output partitions, capped at 1M.
      // Cast to double to avoid overflowing ints or longs
      val sampleSize = math.min(samplePointsPerPartitionHint.toDouble * partitions, 1e6)
      // Assume the input partitions are roughly balanced and over-sample a little bit.
      val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.length).toInt
      val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition)
      if (numItems == 0L) {
        Array.empty
      } else {
        // If a partition contains much more than the average number of items, we re-sample from it
        // to ensure that enough items are collected from that partition.
        val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
        val candidates = ArrayBuffer.empty[(K, Float)]
        val imbalancedPartitions = mutable.Set.empty[Int]
        sketched.foreach { case (idx, n, sample) =>
          if (fraction * n > sampleSizePerPartition) {
            imbalancedPartitions += idx
          } else {
            // The weight is 1 over the sampling probability.
            val weight = (n.toDouble / sample.length).toFloat
            for (key <- sample) {
              candidates += ((key, weight))
            }
          }
        }
        if (imbalancedPartitions.nonEmpty) {
          // Re-sample imbalanced partitions with the desired sampling probability.
          val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains)
          val seed = byteswap32(-rdd.id - 1)
          val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect()
          val weight = (1.0 / fraction).toFloat
          candidates ++= reSampled.map(x => (x, weight))
        }
        RangePartitioner.determineBounds(candidates, math.min(partitions, candidates.size))
      }
    }
  }
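
The final call hands the weighted candidates to RangePartitioner.determineBounds, which is not shown above. A simplified sketch of what that step does (the real implementation additionally skips duplicate boundary keys):

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

// Simplified sketch: sort the weighted candidates by key and emit a boundary key
// each time the cumulative weight crosses a multiple of totalWeight / partitions,
// yielding at most (partitions - 1) upper bounds
def determineBoundsSketch[K: Ordering: ClassTag](
    candidates: Seq[(K, Float)], partitions: Int): Array[K] = {
  val ordered = candidates.sortBy(_._1)
  val sumWeights = ordered.map(_._2.toDouble).sum
  val step = sumWeights / partitions
  val bounds = ArrayBuffer.empty[K]
  var cumWeight = 0.0
  var target = step
  for ((key, weight) <- ordered if bounds.length < partitions - 1) {
    cumWeight += weight
    if (cumWeight >= target) {
      bounds += key
      target += step
    }
  }
  bounds.toArray
}

// Example: four equally weighted candidates split into 2 partitions
// determineBoundsSketch(Seq(("a", 1f), ("b", 1f), ("c", 1f), ("d", 1f)), 2) returns Array("b")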

The function that actually collects the samples:

/**
   * Sketches the input RDD via reservoir sampling on each partition.
   *
   * @param rdd the input RDD to sketch
   * @param sampleSizePerPartition max sample size per partition
   * @return (total number of items, an array of (partitionId, number of items, sample))
   */
  def sketch[K : ClassTag](
      rdd: RDD[K],
      sampleSizePerPartition: Int): (Long, Array[(Int, Long, Array[K])]) = {
    val shift = rdd.id
    // val classTagK = classTag[K] // to avoid serializing the entire partitioner object
    val sketched = rdd.mapPartitionsWithIndex { (idx, iter) =>
      val seed = byteswap32(idx ^ (shift << 16))
      val (sample, n) = SamplingUtils.reservoirSampleAndCount(
        iter, sampleSizePerPartition, seed)
      Iterator((idx, n, sample))
    }.collect()
    val numItems = sketched.map(_._2).sum
    (numItems, sketched)
  }
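
The per-partition sampling delegates to SamplingUtils.reservoirSampleAndCount. Below is a minimal sketch of classic reservoir sampling with the same contract (a fixed-size sample plus the total element count); it illustrates the technique and is not Spark's exact implementation:

import scala.reflect.ClassTag
import scala.util.Random

// Keep at most k elements of an iterator of unknown length, each element ending
// up in the sample with equal probability, and count the elements seen
def reservoirSampleAndCountSketch[T: ClassTag](
    input: Iterator[T], k: Int, seed: Long): (Array[T], Long) = {
  val rng = new Random(seed)
  val reservoir = new Array[T](k)
  var n = 0L
  while (input.hasNext) {
    val item = input.next()
    if (n < k) {
      reservoir(n.toInt) = item // fill the reservoir first
    } else {
      // replace a random slot with probability k / (n + 1)
      val j = (rng.nextDouble() * (n + 1)).toLong
      if (j < k) reservoir(j.toInt) = item
    }
    n += 1
  }
  if (n < k) (reservoir.take(n.toInt), n) else (reservoir, n)
}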

Adapted from: 关于spark中rdd.sortByKey的简单分析
