Spark RDD之Partition

概要

Spark RDD主要由Dependency、Partition、Partitioner组成,Partition是其中之一。一份待处理的原始数据会被按照相应的逻辑(例如jdbc和hdfs的split逻辑)切分成n份,每份数据对应到RDD中的一个Partition,Partition的数量决定了task的数量,影响着程序的并行度,所以理解Partition是了解spark背后运行原理的第一步。

1. Partition定义

在这里插入图片描述

  • 查看spark源码,trait Partition的定义很简单,序列号index、hashCode以及equals方法。
  • Partition和RDD是伴生的,即每一种RDD都有其对应的Partition实现,所以,分析Partition主要是分析其子类。
  • 我们关注两个常用的子类,JdbcPartitionHadoopPartition
    此外,RDD源码中有5个方法,代表其组成,如下:
    在这里插入图片描述
    第二个方法,getPartitions是数据源如何被切分的逻辑,返回值正是Partition,第一个方法compute是消费切割后的Partition的方法,所以学习Partition,要结合getPartitions和compute方法。

2. JdbcPartition例子:

下面是Spark JdbcRDDSuite中一个例子 :

val sc = new SparkContext("local[1]", "test") 
val rdd = new JdbcRDD( 
	sc, () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") }, 
// DATA类型为INTEGER 
	"SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?", 1, 100, 3, 
	(r: ResultSet) => { r.getInt(1)})
	.count()

在这里插入图片描述
查看JdbcPartition实现,相比Partition,主要多了lower和upper这两个字段。
在这里插入图片描述
查看JdbcRDD的getPartitions,按照如上图所示算法将1到100分为3份(partition数量),结果为(1,33)、(34,66)、(67,100),封装为JdbcPartition并返回,这样数据切分的部分就完成了。
在这里插入图片描述
查看JdbcRDD的compute方法,逻辑清晰,将Partition强转为JdbcPartition,获取连接并预处理sql
将 例子中的”SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?”问号分别用Partition的lower和upper替换(即getPartitions切分好的(1,33)、(34,66)、(67,100))并执行查询。
至此,JdbcPartition如何发挥作用就分析完了。

3. HadoopPartition例子

举个简单例子

val sc = new SparkContext("local[1]", "test")
sc.textFile("hdfs://your-file-path").count()


相比Partition,HadoopPartition则多了InputSplit

在这里插入图片描述
spark切分hdfs文件,调用的是Hadoop的API,对这块不熟的同学查看上面InputSplit的链接。

在这里插入图片描述
执行计算的逻辑也很简单,将Partition强转为HadoopPartition,HadoopPartition内有InputSplit对象。调用Hadoop API三个读取数据的相关对象,InputSplit、InputFormat和Reader,读取对应split的数据。这块需要你对Hadoop的掌握,另外我在下面会讲Hadoop split的策略。

4. 决定partition数量的因素

  • Partition数量可以在初始化RDD时指定(如JdbcPartition例子)
  • 不指定的话(如HadoopPartition例子),则 读取spark.default.parallelism配置,不同类型资源管理器取值不同,如下
    在这里插入图片描述
    了解了默认的partition数量,再看一些具体API的partition行为

RDD初始化相关

在这里插入图片描述

通用transformation

在这里插入图片描述

Key-based Transformations

在这里插入图片描述
…未完待续

Partition数量影响及调整

上面分析了决定Partition数量的因数,接下来就该考虑Partition数量的影响以及合适的值。

Partition数量的影响

  1. Partition数量太少
    太少的影响显而易见,就是资源不能充分利用,例如local模式下,有16core,但是Partition数量仅为8的话,有一半的core没利用到。
  2. Partition数量太多
    太多,资源利用没什么问题,但是导致task过多,task的序列化和传输的时间开销增大。
    那么多少的partition数是合适的呢,这里我们参考spark doc给出的建议,Typically you want 2-4 partitions for each CPU in your cluster

Partition调整

  1. repartition
    reparation是coalesce(numPartitions, shuffle = true),repartition不仅会调整Partition数,也会将Partitioner修改为hashPartitioner,产生shuffle操作。
  2. coalesce
    coalesce函数可以控制是否shuffle,但当shuffle为false时,只能减小Partition数,无法增大。

总结

Partition对应的是不同数据源的split逻辑,首先以JdbcPartition和HadoopPartition为例,介绍了Partition的组成,以及如何发挥作用,接下来分析了常见API的Partition行为,最后简单介绍了Partition数量的影响及调整。

附录

-----------------------------hadoopRDD.scala--------------------
/**
 * 包含Hadoop InputSplit的Spark拆分类。
 */
private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: InputSplit)
  extends Partition {

  val inputSplit = new SerializableWritable[InputSplit](s)

  override def hashCode(): Int = 31 * (31 + rddId) + index

  override def equals(other: Any): Boolean = super.equals(other)

  /**
		* 获取运行管道时应添加到用户环境的任何环境变量
    * @return 带有环境变量和相应值的Map,它可能为空
   */
  def getPipeEnvVars(): Map[String, String] = {
    val envVars: Map[String, String] = if (inputSplit.value.isInstanceOf[FileSplit]) {
      val is: FileSplit = inputSplit.value.asInstanceOf[FileSplit]
      //不推荐使用map_input_file来支持mapreduce_map_input_file,但是因为它还没有删除所以都设置了  
      Map("map_input_file" -> is.getPath().toString(),
        "mapreduce_map_input_file" -> is.getPath().toString())
    } else {
      Map()
    }
    envVars
  }
}

  override def getPartitions: Array[Partition] = {
    val jobConf = getJobConf()
    // add the credentials here as this can be called before SparkContext initialized
    SparkHadoopUtil.get.addCredentials(jobConf)
    try {
      val allInputSplits = getInputFormat(jobConf).getSplits(jobConf, minPartitions)
      val inputSplits = if (ignoreEmptySplits) {
        allInputSplits.filter(_.getLength > 0)
      } else {
        allInputSplits
      }
      val array = new Array[Partition](inputSplits.size)
      for (i <- 0 until inputSplits.size) {
        array(i) = new HadoopPartition(id, i, inputSplits(i))
      }
      array
    } catch {
      case e: InvalidInputException if ignoreMissingFiles =>
        logWarning(s"${jobConf.get(FileInputFormat.INPUT_DIR)} doesn't exist and no" +
            s" partitions returned from this path.", e)
        Array.empty[Partition]
    }
  }

  override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
    val iter = new NextIterator[(K, V)] {

      private val split = theSplit.asInstanceOf[HadoopPartition]
      logInfo("Input split: " + split.inputSplit)
      private val jobConf = getJobConf()

      private val inputMetrics = context.taskMetrics().inputMetrics
      private val existingBytesRead = inputMetrics.bytesRead

      // Sets InputFileBlockHolder for the file block's information
      split.inputSplit.value match {
        case fs: FileSplit =>
          InputFileBlockHolder.set(fs.getPath.toString, fs.getStart, fs.getLength)
        case _ =>
          InputFileBlockHolder.unset()
      }

      // Find a function that will return the FileSystem bytes read by this thread. Do this before
      // creating RecordReader, because RecordReader's constructor might read some bytes
      private val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match {
        case _: FileSplit | _: CombineFileSplit =>
          Some(SparkHadoopUtil.get.getFSBytesReadOnThreadCallback())
        case _ => None
      }

      // We get our input bytes from thread-local Hadoop FileSystem statistics.
      // If we do a coalesce, however, we are likely to compute multiple partitions in the same
      // task and in the same thread, in which case we need to avoid override values written by
      // previous partitions (SPARK-13071).
      private def updateBytesRead(): Unit = {
        getBytesReadCallback.foreach { getBytesRead =>
          inputMetrics.setBytesRead(existingBytesRead + getBytesRead())
        }
      }

      private var reader: RecordReader[K, V] = null
      private val inputFormat = getInputFormat(jobConf)
      HadoopRDD.addLocalConfiguration(
        new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(createTime),
        context.stageId, theSplit.index, context.attemptNumber, jobConf)

      reader =
        try {
          inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
        } catch {
          case e: FileNotFoundException if ignoreMissingFiles =>
            logWarning(s"Skipped missing file: ${split.inputSplit}", e)
            finished = true
            null
          // Throw FileNotFoundException even if `ignoreCorruptFiles` is true
          case e: FileNotFoundException if !ignoreMissingFiles => throw e
          case e: IOException if ignoreCorruptFiles =>
            logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e)
            finished = true
            null
        }
      // Register an on-task-completion callback to close the input stream.
      context.addTaskCompletionListener[Unit] { context =>
        // Update the bytes read before closing is to make sure lingering bytesRead statistics in
        // this thread get correctly added.
        updateBytesRead()
        closeIfNeeded()
      }

      private val key: K = if (reader == null) null.asInstanceOf[K] else reader.createKey()
      private val value: V = if (reader == null) null.asInstanceOf[V] else reader.createValue()

      override def getNext(): (K, V) = {
        try {
          finished = !reader.next(key, value)
        } catch {
          case e: FileNotFoundException if ignoreMissingFiles =>
            logWarning(s"Skipped missing file: ${split.inputSplit}", e)
            finished = true
          // Throw FileNotFoundException even if `ignoreCorruptFiles` is true
          case e: FileNotFoundException if !ignoreMissingFiles => throw e
          case e: IOException if ignoreCorruptFiles =>
            logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e)
            finished = true
        }
        if (!finished) {
          inputMetrics.incRecordsRead(1)
        }
        if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
          updateBytesRead()
        }
        (key, value)
      }

      override def close(): Unit = {
        if (reader != null) {
          InputFileBlockHolder.unset()
          try {
            reader.close()
          } catch {
            case e: Exception =>
              if (!ShutdownHookManager.inShutdown()) {
                logWarning("Exception in RecordReader.close()", e)
              }
          } finally {
            reader = null
          }
          if (getBytesReadCallback.isDefined) {
            updateBytesRead()
          } else if (split.inputSplit.value.isInstanceOf[FileSplit] ||
                     split.inputSplit.value.isInstanceOf[CombineFileSplit]) {
            // If we can't get the bytes read from the FS stats, fall back to the split size,
            // which may be inaccurate.
            try {
              inputMetrics.incBytesRead(split.inputSplit.value.getLength)
            } catch {
              case e: java.io.IOException =>
                logWarning("Unable to get input size to set InputMetrics for task", e)
            }
          }
        }
      }
    }
    new InterruptibleIterator[(K, V)](context, iter)
  }


-------------------------------jdbcRDD.scala-----------------
private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition {
  override def index: Int = idx
}

class JdbcRDD[T: ClassTag](
    sc: SparkContext,
    getConnection: () => Connection,
    sql: String,
    lowerBound: Long,
    upperBound: Long,
    numPartitions: Int,
    mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _)
  extends RDD[T](sc, Nil) with Logging {

  override def getPartitions: Array[Partition] = {
    // bounds are inclusive, hence the + 1 here and - 1 on end
    val length = BigInt(1) + upperBound - lowerBound
    (0 until numPartitions).map { i =>
      val start = lowerBound + ((i * length) / numPartitions)
      val end = lowerBound + (((i + 1) * length) / numPartitions) - 1
      new JdbcPartition(i, start.toLong, end.toLong)
    }.toArray
  }


  override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new NextIterator[T]
  {
    context.addTaskCompletionListener[Unit]{ context => closeIfNeeded() }
    val part = thePart.asInstanceOf[JdbcPartition]
    val conn = getConnection()
    val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)

    val url = conn.getMetaData.getURL
    if (url.startsWith("jdbc:mysql:")) {
        // setFetchSize(Integer.MIN_VALUE)是一种特定于mysql驱动程序的强制流式结果,而不是将整个结果集拉入内存。
        // 请参阅:dev.mysql.com/doc/connector-j/5.1/en/connector-j-reference-implementation-notes.html
      stmt.setFetchSize(Integer.MIN_VALUE)
    } else {
      stmt.setFetchSize(100)
    }
    logInfo(s"statement fetch size set to: ${stmt.getFetchSize}")

    stmt.setLong(1, part.lower)
    stmt.setLong(2, part.upper)
    val rs = stmt.executeQuery()

    override def getNext(): T = {
      if (rs.next()) {
        mapRow(rs)
      } else {
        finished = true
        null.asInstanceOf[T]
      }
    }

    override def close() {
      try {
        if (null != rs) {
          rs.close()
        }
      } catch {
        case e: Exception => logWarning("Exception closing resultset", e)
      }
      try {
        if (null != stmt) {
          stmt.close()
        }
      } catch {
        case e: Exception => logWarning("Exception closing statement", e)
      }
      try {
        if (null != conn) {
          conn.close()
        }
        logInfo("closed connection")
      } catch {
        case e: Exception => logWarning("Exception closing connection", e)
      }
    }
  }
}

--------------------------------------



评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值