[Spark Basics]--Custom partitioners in Spark and how to use them

In Spark, the partitioner directly determines the number of partitions of an RDD, which partition each record goes to during the shuffle, and the number of reduce tasks.

Note:

(1) Only key-value RDDs have a partitioner; for a non-key-value RDD the partitioner is None (see the sketch below).
(2) The partition ID range of every RDD is 0 ~ numPartitions-1; this ID determines which partition a record belongs to.
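For example, the following spark-shell sketch (assuming an existing SparkContext sc; data is made up for illustration) shows that a plain RDD has no partitioner, while a key-value RDD gets one after partitionBy:

val plain = sc.parallelize(1 to 10)
println(plain.partitioner)              // None

val pairs = sc.parallelize(Seq(("a", 1), ("b", 2)))
val withPart = pairs.partitionBy(new org.apache.spark.HashPartitioner(4))
println(withPart.partitioner)           // Some(...) wrapping a HashPartitioner with 4 partitions
println(withPart.partitions.length)     // 4, i.e. partition IDs 0 to 3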

 

Pros and cons of the built-in partitioners

 

Drawback of HashPartitioner:

It may distribute data unevenly across partitions; in the extreme case some partitions end up holding all of the RDD's data. When a key's hashCode is negative, Spark wraps the remainder back into a non-negative range with the following helper:

/* Calculates 'x' modulo 'mod', takes to consideration sign of x,
 * i.e. if 'x' is negative, than 'x' % 'mod' is negative too
 * so function return (x % mod) + mod in that case.
 */
 def nonNegativeMod(x: Int, mod: Int): Int = {
   val rawMod = x % mod
   rawMod + (if (rawMod < 0) mod else 0)
 }
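A quick sanity check of that wraparound (a standalone copy of the same logic, not Spark's internal Utils object):

def nonNegativeMod(x: Int, mod: Int): Int = {
  val rawMod = x % mod
  rawMod + (if (rawMod < 0) mod else 0)
}

println(nonNegativeMod("abc".hashCode, 3))  // "abc".hashCode = 96354, and 96354 % 3 = 0
println(nonNegativeMod(-7, 3))              // -7 % 3 = -1 in Scala, so the result is -1 + 3 = 2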
Advantage of RangePartitioner: it tries to keep the amount of data in each partition roughly equal, and the partitions themselves are ordered, so every element in one partition is smaller (or larger) than every element in the next partition. Elements inside a single partition, however, are not guaranteed to be sorted. In short, it maps keys that fall into a given range to a given partition.
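A minimal sketch of that behaviour (assuming a SparkContext sc; the data and partition count are made up for illustration):

import org.apache.spark.RangePartitioner

val pairs = sc.parallelize(Seq(5, 1, 9, 3, 7, 2, 8, 4, 6).map(k => (k, "v")))
val byRange = pairs.partitionBy(new RangePartitioner(3, pairs))

// Every key in partition 0 is smaller than every key in partition 1, and so on,
// but keys inside a single partition are not necessarily sorted.
byRange.mapPartitionsWithIndex { (idx, it) =>
  Iterator((idx, it.map(_._1).toList))
}.collect().foreach(println)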

 

I. The three partitioning strategies

1. The default partitioner (in practice a HashPartitioner)

  /**
   * Choose a partitioner to use for a cogroup-like operation between a number of RDDs.
   *
   * If any of the RDDs already has a partitioner, choose that one.
   *
   * Otherwise, we use a default HashPartitioner. For the number of partitions, if
   * spark.default.parallelism is set, then we'll use the value from SparkContext
   * defaultParallelism, otherwise we'll use the max number of upstream partitions.
   *
   * Unless spark.default.parallelism is set, the number of partitions will be the
   * same as the number of partitions in the largest upstream RDD, as this should
   * be least likely to cause out-of-memory errors.
   *
   * We use two method parameters (rdd, others) to enforce callers passing at least 1 RDD.
   */
  def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = {
    val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.size).reverse
    for (r <- bySize if r.partitioner.isDefined && r.partitioner.get.numPartitions > 0) {
      return r.partitioner.get
    }
    if (rdd.context.conf.contains("spark.default.parallelism")) {
      new HashPartitioner(rdd.context.defaultParallelism)
    } else {
      new HashPartitioner(bySize.head.partitions.size)
    }
  }
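A small sketch of how these rules play out (assuming a SparkContext sc and that spark.default.parallelism is not set; the data is made up):

val left  = sc.parallelize(Seq(("a", 1), ("b", 2)), 4)    // 4 partitions, no partitioner
val right = sc.parallelize(Seq(("a", 9), ("c", 3)), 2)
  .partitionBy(new org.apache.spark.HashPartitioner(2))   // already has a partitioner

// No input has a partitioner, so a HashPartitioner is built from the largest
// upstream partition count (4 here).
println(left.reduceByKey(_ + _).partitions.length)        // expected: 4

// One input already has a partitioner, so the join should reuse it.
println(left.join(right).partitions.length)               // expected: 2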
 

2. HashPartitioner

How HashPartitioner works: for a given key, compute its hashCode and take it modulo the number of partitions; if the remainder is negative, add the number of partitions to it. The resulting value is the partition ID of that key. The implementation is as follows:

/**
 * A [[org.apache.spark.Partitioner]] that implements hash-based partitioning using
 * Java's `Object.hashCode`.
 *
 * Java arrays have hashCodes that are based on the arrays' identities rather than their contents,
 * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will
 * produce an unexpected or incorrect result.
 */
class HashPartitioner(partitions: Int) extends Partitioner {
  require(partitions >= 0, s"Number of partitions ($partitions) cannot be negative.")

  def numPartitions: Int = partitions

  def getPartition(key: Any): Int = key match {
    case null => 0
    case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
  }

  override def equals(other: Any): Boolean = other match {
    case h: HashPartitioner =>
      h.numPartitions == numPartitions
    case _ =>
      false
  }

  override def hashCode: Int = numPartitions
}
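Following the formula by hand with a few illustrative keys (only Spark on the classpath is needed, no SparkContext):

import org.apache.spark.HashPartitioner

val p = new HashPartitioner(3)
println(p.getPartition(42))       // 42.hashCode = 42, and 42 % 3 = 0
println(p.getPartition("apple"))  // nonNegativeMod("apple".hashCode, 3)
println(p.getPartition(null))     // null keys always go to partition 0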
 

3. RangePartitioner

What RangePartitioner does: it maps keys within a given range to a given partition. In the implementation, computing the partition boundaries is the crucial part; that logic lives in rangeBounds.

The code is as follows:

/**
 * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly
 * equal ranges. The ranges are determined by sampling the content of the RDD passed in.
 *
 * Note that the actual number of partitions created by the RangePartitioner might not be the same
 * as the `partitions` parameter, in the case where the number of sampled records is less than
 * the value of `partitions`.
 */
class RangePartitioner[K : Ordering : ClassTag, V](
    partitions: Int,
    rdd: RDD[_ <: Product2[K, V]],
    private var ascending: Boolean = true)
  extends Partitioner {

  // We allow partitions = 0, which happens when sorting an empty RDD under the default settings.
  require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.")

  private var ordering = implicitly[Ordering[K]]

  // An array of upper bounds for the first (partitions - 1) partitions
  private var rangeBounds: Array[K] = {
    if (partitions <= 1) {
      Array.empty
    } else {
      // This is the sample size we need to have roughly balanced output partitions, capped at 1M.
      val sampleSize = math.min(20.0 * partitions, 1e6)
      // Assume the input partitions are roughly balanced and over-sample a little bit.
      val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.size).toInt
      val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition)
      if (numItems == 0L) {
        Array.empty
      } else {
        // If a partition contains much more than the average number of items, we re-sample from it
        // to ensure that enough items are collected from that partition.
        val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
        val candidates = ArrayBuffer.empty[(K, Float)]
        val imbalancedPartitions = mutable.Set.empty[Int]
        sketched.foreach { case (idx, n, sample) =>
          if (fraction * n > sampleSizePerPartition) {
            imbalancedPartitions += idx
          } else {
            // The weight is 1 over the sampling probability.
            val weight = (n.toDouble / sample.size).toFloat
            for (key <- sample) {
              candidates += ((key, weight))
            }
          }
        }
        if (imbalancedPartitions.nonEmpty) {
          // Re-sample imbalanced partitions with the desired sampling probability.
          val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains)
          val seed = byteswap32(-rdd.id - 1)
          val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect()
          val weight = (1.0 / fraction).toFloat
          candidates ++= reSampled.map(x => (x, weight))
        }
        RangePartitioner.determineBounds(candidates, partitions)
      }
    }
  }

  def numPartitions: Int = rangeBounds.length + 1

  private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]

  def getPartition(key: Any): Int = {
    val k = key.asInstanceOf[K]
    var partition = 0
    if (rangeBounds.length <= 128) {
      // If we have less than 128 partitions naive search
      while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) {
        partition += 1
      }
    } else {
      // Determine which binary search method to use only once.
      partition = binarySearch(rangeBounds, k)
      // binarySearch either returns the match location or -[insertion point]-1
      if (partition < 0) {
        partition = -partition-1
      }
      if (partition > rangeBounds.length) {
        partition = rangeBounds.length
      }
    }
    if (ascending) {
      partition
    } else {
      rangeBounds.length - partition
    }
  }

  override def equals(other: Any): Boolean = other match {
    case r: RangePartitioner[_, _] =>
      r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending
    case _ =>
      false
  }

  override def hashCode(): Int = {
    val prime = 31
    var result = 1
    var i = 0
    while (i < rangeBounds.length) {
      result = prime * result + rangeBounds(i).hashCode
      i += 1
    }
    result = prime * result + ascending.hashCode
    result
  }

  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
    val sfactory = SparkEnv.get.serializer
    sfactory match {
      case js: JavaSerializer => out.defaultWriteObject()
      case _ =>
        out.writeBoolean(ascending)
        out.writeObject(ordering)
        out.writeObject(binarySearch)

        val ser = sfactory.newInstance()
        Utils.serializeViaNestedStream(out, ser) { stream =>
          stream.writeObject(scala.reflect.classTag[Array[K]])
          stream.writeObject(rangeBounds)
        }
    }
  }

  @throws(classOf[IOException])
  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
    val sfactory = SparkEnv.get.serializer
    sfactory match {
      case js: JavaSerializer => in.defaultReadObject()
      case _ =>
        ascending = in.readBoolean()
        ordering = in.readObject().asInstanceOf[Ordering[K]]
        binarySearch = in.readObject().asInstanceOf[(Array[K], K) => Int]

        val ser = sfactory.newInstance()
        Utils.deserializeViaNestedStream(in, ser) { ds =>
          implicit val classTag = ds.readObject[ClassTag[Array[K]]]()
          rangeBounds = ds.readObject[Array[K]]()
        }
    }
  }
}
The sampling algorithm it relies on is reservoir sampling; see https://www.iteblog.com/archives/1525.html for details.
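For reference, here is a minimal sketch of reservoir sampling itself (a generic version, not Spark's internal implementation): keep k items from a stream of unknown length so that every item has an equal chance of being kept.

import scala.reflect.ClassTag
import scala.util.Random

def reservoirSample[T: ClassTag](input: Iterator[T], k: Int, rng: Random = new Random): Array[T] = {
  val reservoir = new Array[T](k)
  var i = 0
  // Fill the reservoir with the first k items.
  while (i < k && input.hasNext) {
    reservoir(i) = input.next()
    i += 1
  }
  // After that, the (i+1)-th item replaces a random slot with probability k / (i + 1).
  while (input.hasNext) {
    val item = input.next()
    val j = rng.nextInt(i + 1)
    if (j < k) reservoir(j) = item
    i += 1
  }
  if (i < k) reservoir.take(i) else reservoir
}

// e.g. reservoirSample((1 to 100000).iterator, 20) keeps 20 uniformly chosen numbers.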

 

II. Writing a custom partitioner

A custom partitioner must extend org.apache.spark.Partitioner. An implementation looks like this:

import org.apache.spark.Partitioner
/**
  * Created by Jeff Yang on 2017/3/30
  * Update date:
  * Time: 18:03
  * Describe :
  * Result of Test:
  * Command:
  * Email: highfei2011@126.com
  */
class MySparkPartition(numParts: Int) extends Partitioner {

  override def numPartitions: Int = numParts

  /**
    * Custom partitioning logic goes here: in this example the keys are URLs
    * and records are grouped by the URL's host.
    * @param key the record key (expected to be a URL string)
    * @return the partition ID, in the range 0 to numPartitions - 1
    */
  override def getPartition(key: Any): Int = {
    val domain = new java.net.URL(key.toString).getHost()
    val code = (domain.hashCode % numPartitions)
    if (code < 0) {
      code + numPartitions
    } else {
      code
    }
  }
  override def equals(other: Any): Boolean = other match {
    case mypartition: MySparkPartition =>
      mypartition.numPartitions == numPartitions
    case _ =>
      false
  }
  override def hashCode: Int = numPartitions

}
/**
 * def numPartitions: returns the number of partitions you want to create;
 * def getPartition: computes the partition ID for a given key; the return value must be
 *                   in the range 0 to numPartitions - 1;
 * equals(): the standard Java equality method; Spark needs it to compare whether the
 *           partitioning of two RDDs is the same.
 */
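Because getPartition is ordinary Scala, the logic can be checked locally without a cluster (the URLs below are just examples):

val p = new MySparkPartition(3)
println(p.getPartition("http://www.example.com/a/b"))    // same host ...
println(p.getPartition("http://www.example.com/c"))      // ... same partition
println(p.getPartition("http://spark.apache.org/docs"))  // a different host may land elsewhere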
 

III. Using the partitioner

Once the custom partitioner is created, it is used as follows:

import org.apache.spark.{SparkConf, SparkContext}

/**
  * Created by Jeff Yang on 2017/3/30
  * Update date:
  * Time: 18:47
  * Describe : Use the custom partitioner
  * Result of Test:
  * Command:
  * Email: highfei2011@126.com
  */
object UseMyPartitioner {

  def main(args: Array[String]) {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("TestMyPartitioner")
      .set("spark.app.id", "test-partition-id")
    val sc = new SparkContext(conf)

    // Read a file from HDFS
    val lines = sc.textFile("hdfs://hadoop2:8020/user/test/word.txt")
    // Note: the RDD must be key-value; with MySparkPartition the keys also need to be
    // valid URLs, because getPartition parses them with java.net.URL.
    val splitMap = lines.flatMap(line => line.split("\t")).map(word => (word, 2))

    // Repartition with the custom partitioner (3 partitions) and save the result
    splitMap.partitionBy(new MySparkPartition(3)).saveAsTextFile("F:/partrion/test")

    sc.stop()


  }

}
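To see how the records were actually distributed, a small debugging sketch like the following can be added inside main() before sc.stop() (it reuses the splitMap value from the example above):

val partitioned = splitMap.partitionBy(new MySparkPartition(3))
partitioned
  .mapPartitionsWithIndex { (idx, it) => Iterator(s"partition $idx: ${it.size} records") }
  .collect()
  .foreach(println)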
