org.apache.spark.Partitioner defines how a key-value RDD is partitioned by its keys: the getPartition method maps a key of any type to an Int partition ID.
// The abstract base class defines the numPartitions value and the getPartition method
abstract class Partitioner extends Serializable {
  def numPartitions: Int
  def getPartition(key: Any): Int
}
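To make the contract concrete, here is a minimal hypothetical subclass (the class name and routing rule are invented for illustration): an implementation just supplies these two members and can then be passed to rdd.partitionBy.
import org.apache.spark.Partitioner

// Hypothetical example: even Int keys go to partition 0, odd ones to
// partition 1, and anything else (including null) to partition 0.
class ParityPartitioner extends Partitioner {
  def numPartitions: Int = 2
  def getPartition(key: Any): Int = key match {
    case i: Int => math.abs(i % 2)
    case _      => 0
  }
}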
// By default an existing partitioner from the inputs is reused; otherwise a
// HashPartitioner is created, sized by spark.default.parallelism if set,
// or else by the largest input's partition count
object Partitioner {
  def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = {
    // Order all input RDDs by partition count, largest first
    val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.size).reverse
    // Reuse the first existing partitioner found among the inputs
    for (r <- bySize if r.partitioner.isDefined) {
      return r.partitioner.get
    }
    if (rdd.context.conf.contains("spark.default.parallelism")) {
      new HashPartitioner(rdd.context.defaultParallelism)
    } else {
      new HashPartitioner(bySize.head.partitions.size)
    }
  }
}
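defaultPartitioner is what shuffle operators such as join, cogroup, and reduceByKey fall back to when no partitioner is given explicitly. A small sketch of that fallback, assuming spark.default.parallelism is left unset:
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("demo"))
val a = sc.parallelize(Seq(("a", 1), ("b", 2)), numSlices = 4)
val b = sc.parallelize(Seq(("a", 3), ("c", 4)), numSlices = 2)
// Neither input carries a partitioner and spark.default.parallelism is unset,
// so defaultPartitioner builds a HashPartitioner sized by the larger input.
val joined = a.join(b)
println(joined.partitions.size) // 4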
Spark ships two concrete subclass implementations:
// The HashPartitioner subclass takes the key's hash code modulo numPartitions
class HashPartitioner(partitions: Int) extends Partitioner {
  def numPartitions: Int = partitions
  def getPartition(key: Any): Int = key match {
    case null => 0
    case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
  }
}
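Utils.nonNegativeMod is needed because %, as in Java, can return a negative remainder for a negative hashCode; it folds the result back into [0, numPartitions). Its logic is essentially:
// Essentially what Utils.nonNegativeMod does: shift a possibly negative
// remainder back into [0, mod).
def nonNegativeMod(x: Int, mod: Int): Int = {
  val rawMod = x % mod
  rawMod + (if (rawMod < 0) mod else 0)
}

nonNegativeMod(-7, 3)  // 2, whereas -7 % 3 == -1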
// The RangePartitioner subclass spreads an arbitrary RDD roughly evenly across
// partitions, estimating the key distribution with reservoir sampling
class RangePartitioner[K : Ordering : ClassTag, V](
    @transient partitions: Int,
    @transient rdd: RDD[_ <: Product2[K, V]],
    private var ascending: Boolean = true)
  extends Partitioner
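RangePartitioner is what sortByKey builds internally, but it can also be applied directly. A usage sketch, reusing the sc from the example above:
import org.apache.spark.RangePartitioner

val pairs = sc.parallelize(1 to 1000).map(i => (i, i.toString))
// The constructor samples `pairs` to choose range boundaries, so after
// partitionBy each partition holds a contiguous, roughly equal key range.
val ranged = pairs.partitionBy(new RangePartitioner(4, pairs))
println(ranged.partitions.size) // 4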