自定义分区器
HashPartitioner源码解读:
/*class HashPartitioner(partitions: Int) extends Partitioner {
// 传进来的分区个数必须是大于等于0的,不然它会报错
require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")
// 重写分区器的抽象方法
// 记录它有多少个分区 就是外面传进来的参数
def numPartitions: Int = partitions
//
def getPartition(key: Any): Int = key match {
case null => 0
// 根据传进来的一个int值,返回分区号
case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
}
// Object里面的方法,用来判断这个分区器是否一样,用来判断hashcode值的
override def equals(other: Any): Boolean = other match {
case h: HashPartitioner =>
h.numPartitions == numPartitions
case _ =>
false
}
override def hashCode: Int = numPartitions
}*/
代码实现:
package com.huc.Spark1.KeyAndValue
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
import org.apache.spark.{HashPartitioner, Partitioner, SparkConf, SparkContext}
object Test02_CustomPartitioner {
def main(args: Array[String]): Unit = {
//1.创建SparkConf并设置App名称
val conf: SparkConf = new SparkConf().setAppName("SparkCore").setMaster("local[*]")
//2.创建SparkContext,该对象是提交Spark App的入口
val sc: SparkContext = new SparkContext(conf)
//3.使用Scala进行spark编程
val rdd: RDD[Int] = sc.makeRDD(List(1, 2, 3, 4, 5), 3) // 0,1 1,3 5
println(rdd.mapPartitionsWithIndex((index, datas) => datas.map((index, _))).collect().mkString(","))
val value: RDD[(Int, Int)] = rdd.map((_, 1))
val value1: RDD[(Int, Int)] = value.partitionBy(new MyPartitioner(2))
val value2: RDD[(Int, (Int, Int))] = value1.mapPartitionsWithIndex((index, datas) => datas.map(data => (index, data)))
value2.collect().foreach(println)
//4.关闭连接
sc.stop()
}
class MyPartitioner(partitions: Int) extends Partitioner {
// 分区器有几个分区
override def numPartitions: Int = partitions
// 使用返回值 确定分区分到哪个分区
// 在spark中分区器的默认逻辑是只能使用key对它进行分区,不能使用value分区
// 重写的方法,参数类型不可以改变
override def getPartition(key: Any): Int = {
// 使用模式匹配匹配类型
key match {
case i: Int => if (i <= 3) 0 else 1
case _ => 0
}
}
}
/*class HashPartitioner(partitions: Int) extends Partitioner {
// 传进来的分区个数必须是大于等于0的,不然它会报错
require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")
// 重写分区器的抽象方法
// 记录它有多少个分区 就是外面传进来的参数
def numPartitions: Int = partitions
//
def getPartition(key: Any): Int = key match {
case null => 0
// 根据传进来的一个int值,返回分区号
case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
}
// Object里面的方法,用来判断这个分区器是否一样,用来判断hashcode值的
override def equals(other: Any): Boolean = other match {
case h: HashPartitioner =>
h.numPartitions == numPartitions
case _ =>
false
}
override def hashCode: Int = numPartitions
}*/
}