Spark version: 2.4.0
Source location: org.apache.spark.rdd.OrderedRDDFunctions
Usage example:
import scala.reflect.ClassTag

import org.apache.spark.{Partition, SparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

object SortByKeyDemo {
  // Print how the elements are distributed across partitions
  def printPartAndElement[T: ClassTag](rdd: RDD[T]): Unit = {
    val parts: Array[Partition] = rdd.partitions
    for (p <- parts) {
      val partIndex = p.index
      // Keep only the elements that live in the current partition
      val partRdd: RDD[T] = rdd.mapPartitionsWithIndex {
        case (index: Int, value: Iterator[T]) =>
          if (index == partIndex)
            value
          else
            Iterator()
      }
      println("Partition id: " + partIndex)
      partRdd.foreach(println)
    }
  }

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession
      .builder()
      .appName("SortByKeyDemo")
      .config("spark.master", "local[*]")
      .config("spark.driver.host", "localhost")
      .getOrCreate()
    val sc: SparkContext = spark.sparkContext
    sc.setLogLevel("ERROR")

    val sourceRdd1: RDD[(String, Int)] = sc.parallelize(Seq(("a", 2), ("c", 27), ("f", 3), ("f", 0), ("f", 9),
      ("d", 3), ("d", 2), ("c", 9), ("a", 1), ("b", 2), ("b", 3), ("b", 3), ("c", 2)))
    val sortRdd1: RDD[(String, Int)] = sourceRdd1.sortByKey(false, 2) // 2 partitions, descending
    println("Partition contents of sortRdd1:")
    printPartAndElement(sortRdd1)

    val sourceRdd2: RDD[((String, Int), Int)] = sc.parallelize(Seq((("a", 2), 9), (("c", 27), 3), (("f", 3), 4),
      (("f", 0), 0), (("f", 9), 1), (("d", 3), 7),
      (("d", 2), 2), (("c", 9), 12), (("a", 1), 4), (("b", 2), 34), (("b", 3), 1), (("b", 3), 6), (("c", 2), 8)))
    val sortRdd2: RDD[((String, Int), Int)] = sourceRdd2.sortByKey(false, 3) // 3 partitions, descending
    println("-------------------------------------------------")
    println("Partition contents of sortRdd2:")
    printPartAndElement(sortRdd2)

    spark.stop()
  }
}
Output:
The output shows that the sortByKey operator orders records by key, and that each partition holds one contiguous range of keys.
sourceRdd1: RDD[(String, Int)] is sorted by its String key.
sourceRdd2: RDD[((String, Int), Int)] is sorted first by the String in the (String, Int) key, then by the Int.
Values take no part in the sort, so records that share a key are not ordered among themselves.
Partition contents of sortRdd1:
Partition id: 0
(f,3)
(f,0)
(f,9)
(d,3)
(d,2)
Partition id: 1
(c,27)
(c,9)
(c,2)
(b,2)
(b,3)
(b,3)
(a,2)
(a,1)
-------------------------------------------------
Partition contents of sortRdd2:
Partition id: 0
((f,9),1)
((f,3),4)
((f,0),0)
((d,3),7)
Partition id: 1
((d,2),2)
((c,27),3)
((c,9),12)
((c,2),8)
Partition id: 2
((b,3),1)
((b,3),6)
((b,2),34)
((a,2),9)
((a,1),4)
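The composite-key order seen for sortRdd2 can be reproduced without Spark. The following is a minimal sketch, assuming that sortByKey relies on the implicit Ordering of the key type; the object name TupleOrderingSketch and the sample keys are chosen here for illustration only.

```scala
object TupleOrderingSketch {
  // A handful of (String, Int) keys in no particular order (illustrative data)
  val keys: Seq[(String, Int)] =
    Seq(("a", 2), ("c", 27), ("f", 3), ("f", 0), ("c", 9), ("a", 1))

  // Scala's implicit Ordering for Tuple2 compares the first element and
  // falls back to the second element only when the first ones are equal.
  // Reversing it gives the descending order requested by sortByKey(false, ...).
  val descending: Seq[(String, Int)] = keys.sorted(Ordering[(String, Int)].reverse)
}
```

Note that ("c", 27) sorts before ("c", 9) in descending order because the Int component is compared numerically, not as text.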
The source code is as follows:
/**
 * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
 * `collect` or `save` on the resulting RDD will return or output an ordered list of records
 * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
 * order of the keys).
 */
// TODO: this currently doesn't work on P other than Tuple2! (This operator only works on Tuple2 data.)
// ascending: Boolean = true selects ascending or descending order
// The data is distributed by key into (numPartitions: Int) ranges
def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.length)
    : RDD[(K, V)] = self.withScope
{
  // Sample the data and use the samples to decide the boundary of each range
  val part = new RangePartitioner(numPartitions, self, ascending)
  new ShuffledRDD[K, V, V](self, part)
    .setKeyOrdering(if (ascending) ordering else ordering.reverse)
}
Below is the partitioning method provided by RangePartitioner.
It finds the position of the key within rangeBounds to determine which range the key falls into, and assigns the partition accordingly. For descending order the partition index is mirrored at the end.
def getPartition(key: Any): Int = {
  val k = key.asInstanceOf[K]
  var partition = 0
  if (rangeBounds.length <= 128) {
    // If we have less than 128 partitions naive search
    while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) {
      partition += 1
    }
  } else {
    // Determine which binary search method to use only once.
    partition = binarySearch(rangeBounds, k)
    // binarySearch either returns the match location or -[insertion point]-1
    if (partition < 0) {
      partition = -partition-1
    }
    if (partition > rangeBounds.length) {
      partition = rangeBounds.length
    }
  }
  if (ascending) {
    partition
  } else {
    rangeBounds.length - partition
  }
}
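The lookup above can be tried out in isolation. The following is a rough sketch with Int keys, not Spark's code: the object RangeLookupSketch and its method names are hypothetical, and java.util.Arrays.binarySearch stands in for the binarySearch helper that the partitioner selects internally.

```scala
object RangeLookupSketch {
  // Linear scan: count how many upper bounds the key is strictly greater than.
  def linear(bounds: Array[Int], k: Int): Int = {
    var partition = 0
    while (partition < bounds.length && k > bounds(partition)) {
      partition += 1
    }
    partition
  }

  // Binary search: a negative result encodes -(insertion point) - 1.
  def binary(bounds: Array[Int], k: Int): Int = {
    val r = java.util.Arrays.binarySearch(bounds, k)
    val partition = if (r < 0) -r - 1 else r
    math.min(partition, bounds.length)
  }

  // Pick the search strategy by size, then mirror the index for descending order.
  def partitionOf(bounds: Array[Int], k: Int, ascending: Boolean): Int = {
    val p = if (bounds.length <= 128) linear(bounds, k) else binary(bounds, k)
    if (ascending) p else bounds.length - p
  }
}
```

With the bounds Array(10, 20) there are three partitions: in ascending order keys up to 10 go to partition 0, keys in (10, 20] to partition 1, and larger keys to partition 2. In descending order the same key ranges map to partitions 2, 1 and 0, which is why the largest keys appear in partition 0 of the demo output.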
The sample size is derived from the parameters, and the collected samples are then used to determine the boundaries of the ranges:
// An array of upper bounds for the first (partitions - 1) partitions
private var rangeBounds: Array[K] = {
  if (partitions <= 1) {
    Array.empty
  } else {
    // This is the sample size we need to have roughly balanced output partitions, capped at 1M.
    // Cast to double to avoid overflowing ints or longs
    val sampleSize = math.min(samplePointsPerPartitionHint.toDouble * partitions, 1e6)
    // Assume the input partitions are roughly balanced and over-sample a little bit.
    val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.length).toInt
    val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition)
    if (numItems == 0L) {
      Array.empty
    } else {
      // If a partition contains much more than the average number of items, we re-sample from it
      // to ensure that enough items are collected from that partition.
      val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
      val candidates = ArrayBuffer.empty[(K, Float)]
      val imbalancedPartitions = mutable.Set.empty[Int]
      sketched.foreach { case (idx, n, sample) =>
        if (fraction * n > sampleSizePerPartition) {
          imbalancedPartitions += idx
        } else {
          // The weight is 1 over the sampling probability.
          val weight = (n.toDouble / sample.length).toFloat
          for (key <- sample) {
            candidates += ((key, weight))
          }
        }
      }
      if (imbalancedPartitions.nonEmpty) {
        // Re-sample imbalanced partitions with the desired sampling probability.
        val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains)
        val seed = byteswap32(-rdd.id - 1)
        val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect()
        val weight = (1.0 / fraction).toFloat
        candidates ++= reSampled.map(x => (x, weight))
      }
      RangePartitioner.determineBounds(candidates, math.min(partitions, candidates.size))
    }
  }
}
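To make the arithmetic concrete, here is a small sketch of the sizing formulas from the code above. It assumes samplePointsPerPartitionHint is 20, which I believe is the default in this version but which is not shown in the excerpt; the object SampleSizeSketch and the numbers in the note below are illustrative only.

```scala
object SampleSizeSketch {
  // Assumed value of the hint; not confirmed by the excerpt above
  val samplePointsPerPartitionHint: Int = 20

  // Total number of samples wanted, capped at one million
  def sampleSize(partitions: Int): Double =
    math.min(samplePointsPerPartitionHint.toDouble * partitions, 1e6)

  // Samples taken from each input partition, over-sampled by a factor of 3
  def sampleSizePerPartition(partitions: Int, inputPartitions: Int): Int =
    math.ceil(3.0 * sampleSize(partitions) / inputPartitions).toInt

  // Overall sampling probability implied by the total item count
  def fraction(partitions: Int, numItems: Long): Double =
    math.min(sampleSize(partitions) / math.max(numItems, 1L), 1.0)

  // An input partition with n items is re-sampled when its fair share of
  // samples exceeds what the first pass was allowed to collect from it.
  def isImbalanced(partitions: Int, inputPartitions: Int, numItems: Long, n: Long): Boolean =
    fraction(partitions, numItems) * n > sampleSizePerPartition(partitions, inputPartitions)
}
```

For 3 output partitions and 4 input partitions, sampleSize is 60 and sampleSizePerPartition is ceil(180 / 4) = 45. With 10000 items in total, fraction is 0.006, so an input partition holding 9000 items (expected share 54 > 45) is treated as imbalanced and re-sampled, while one holding 2500 items (expected share 15) is not.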
The function that actually collects the samples:
/**
 * Sketches the input RDD via reservoir sampling on each partition.
 *
 * @param rdd the input RDD to sketch
 * @param sampleSizePerPartition max sample size per partition
 * @return (total number of items, an array of (partitionId, number of items, sample))
 */
def sketch[K : ClassTag](
    rdd: RDD[K],
    sampleSizePerPartition: Int): (Long, Array[(Int, Long, Array[K])]) = {
  val shift = rdd.id
  // val classTagK = classTag[K] // to avoid serializing the entire partitioner object
  val sketched = rdd.mapPartitionsWithIndex { (idx, iter) =>
    val seed = byteswap32(idx ^ (shift << 16))
    val (sample, n) = SamplingUtils.reservoirSampleAndCount(
      iter, sampleSizePerPartition, seed)
    Iterator((idx, n, sample))
  }.collect()
  val numItems = sketched.map(_._2).sum
  (numItems, sketched)
}
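The body of SamplingUtils.reservoirSampleAndCount is not shown above. As a rough illustration of the idea, the following sketch implements textbook reservoir sampling (Algorithm R) that also counts the items it has seen. It is not Spark's implementation: the object ReservoirSketch, the method sampleAndCount and the use of scala.util.Random are assumptions made for this example.

```scala
import scala.reflect.ClassTag
import scala.util.Random

object ReservoirSketch {
  // Returns up to k uniformly sampled items together with the total item count.
  def sampleAndCount[T: ClassTag](iter: Iterator[T], k: Int, seed: Long): (Array[T], Long) = {
    val rng = new Random(seed)
    val reservoir = new Array[T](k)
    var n = 0L
    while (iter.hasNext) {
      val item = iter.next()
      if (n < k) {
        // Fill the reservoir with the first k items
        reservoir(n.toInt) = item
      } else {
        // Item number n + 1 replaces a random slot with probability k / (n + 1)
        val j = (rng.nextDouble() * (n + 1)).toLong
        if (j < k) reservoir(j.toInt) = item
      }
      n += 1
    }
    // If fewer than k items were seen, return only the filled part
    val sample = if (n < k) reservoir.take(n.toInt) else reservoir
    (sample, n)
  }
}
```

A single pass over the iterator yields both the sample and the count n, which is what lets the caller compute the weight n / sample.length for each partition without scanning the data a second time.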