我们都知道Spark内部提供了HashPartitioner
和RangePartitioner
两种分区策略,这两种分区策略在很多情况下都适合我们的场景。但是有些情况下,Spark内部不能符合咱们的需求,这时候我们就可以自定义分区策略。为此,Spark提供了相应的接口,我们只需要扩展Partitioner
抽象类,然后实现里面的三个方法:
// 这个方法需要返回你想要创建分区的个数
override def numPartitions: Int = numParts
// 这个函数需要对输入的key做计算,然后返回该key的分区ID,范围一定是0到numPartitions-1
override def getPartition(key: Any): Int
// Java标准的判断相等的函数,之所以要求用户实现这个函数是因为Spark内部会比较两个RDD的分区是否一样
override def equals(other: Any): Boolean
实现案例
假如我们想把来自同一个域名的URL放到一台节点上,比如http://www.baidu.com/aa/120
和http://www.baidu.com/bb/110
,如果你使用HashPartitioner
,这两个URL的Hash值可能不一样,这就使得这两个URL被放到不同的节点上。所以这种情况下我们就需要自定义我们的分区策略,可以如下实现:
import org.apache.spark.{Partitioner, SparkConf, SparkContext, TaskContext}
class SelfPartitioner(numParts:Int) extends Partitioner {
//覆盖分区数
override def numPartitions: Int = numParts
//覆盖分区号获取函数
override def getPartition(key: Any): Int = {
val domain = new java.net.URL(key.toString).getHost()
val code = (domain.hashCode % numPartitions)
if (code < 0) {
code + numPartitions
} else {
code
}
}
override def equals(other: Any): Boolean = other match {
case aa: SelfPartitioner =>
aa.numPartitions == numPartitions
case _ =>
false
}
}
object SparkSelfPartitioner {
def main(args: Array[String]): Unit = {
val conf=new SparkConf().setAppName("partitioner").setMaster("local")
val sc=new SparkContext(conf)
sc.setLogLevel("ERROR")
val data=sc.parallelize(List("http://www.baidu.com/aa/120","http://www.baidu.com/bb/110"))
data.map((_,1)).partitionBy(new SelfPartitioner(3)).foreachPartition(t => {
val id = TaskContext.get.partitionId
println("分区号:" + id)
t.foreach( one => {
println(one)
})
})
sc.stop()
}
}