Spark RDD Source Code Reading Notes

abstract class RDD[T: ClassTag](
    @transient private var _sc: SparkContext,
    @transient private var deps: Seq[Dependency[_]]
  ) extends Serializable with Logging

From the RDD class definition above, creating an RDD requires a SparkContext and the RDD's dependencies.


Every RDD has five characteristics (a minimal sketch tying all five together follows this list):

1. A list of partitions, i.e. the data can be split into pieces. As in Hadoop, only data that can be split can be computed in parallel.

@transient private var partitions_ : Array[Partition] = null

2. A function that computes each partition; this is the compute function discussed below.

@DeveloperApi
  def compute(split: Partition, context: TaskContext): Iterator[T]

3. A list of dependencies on other RDDs. Dependencies are further divided into wide and narrow dependencies, though not every RDD has them.

@transient private var deps: Seq[Dependency[_]]

// Our dependencies and partitions will be gotten by calling subclass's methods below, and will
  // be overwritten when we're checkpointed
  private var dependencies_ : Seq[Dependency[_]] = null
  @transient private var partitions_ : Array[Partition] = null

  /**
   * Get the list of dependencies of this RDD, taking into account whether the
   * RDD is checkpointed or not.
   */
  final def dependencies: Seq[Dependency[_]] = {
    checkpointRDD.map(r => List(new OneToOneDependency(r))).getOrElse {
      if (dependencies_ == null) {
        dependencies_ = getDependencies
      }
      dependencies_
    }
  }


Note that dependencies show up in two places, and they are not the same thing: the constructor parameter deps passed in at creation time, and the cached dependencies_ field, which is filled in lazily by calling the subclass's getDependencies. As the comment says, both the cached dependencies and partitions are overwritten once the RDD is checkpointed.

4. Optional: key-value RDDs are partitioned by a hash partitioner, similar to the Partitioner interface in MapReduce, which controls which reducer a key is sent to.

5. Optional: preferred locations for computing each partition; for example, the locations of an HDFS block should be the preferred locations for computing on it.

  /**
   * Optionally overridden by subclasses to specify placement preferences.
   */
  protected def getPreferredLocations(split: Partition): Seq[String] = Nil

  /** Optionally overridden by subclasses to specify how they are partitioned. */
  @transient val partitioner: Option[Partitioner] = None
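
To tie the five characteristics together, here is a minimal sketch of a hypothetical source RDD (the class and all names are made up for illustration, not taken from the Spark source):

  import org.apache.spark.{Partition, Partitioner, SparkContext, TaskContext}
  import org.apache.spark.rdd.RDD

  // A made-up partition type: each partition covers a sub-range of the numbers.
  class RangePartition(val idx: Int, val from: Int, val to: Int) extends Partition {
    override def index: Int = idx
  }

  // Characteristic 3: a source RDD has no parents, so the dependency list is Nil.
  class RangeRDD(sc: SparkContext, n: Int, slices: Int) extends RDD[Int](sc, Nil) {

    // Characteristic 1: the list of partitions.
    override protected def getPartitions: Array[Partition] =
      Array.tabulate[Partition](slices) { i =>
        new RangePartition(i, i * n / slices, (i + 1) * n / slices)
      }

    // Characteristic 2: a function that computes each partition.
    override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
      val p = split.asInstanceOf[RangePartition]
      Iterator.range(p.from, p.to)
    }

    // Characteristic 5 (optional): preferred locations; none for in-memory ranges.
    override protected def getPreferredLocations(split: Partition): Seq[String] = Nil

    // Characteristic 4 (optional): no partitioner, since this is not key-value data.
    @transient override val partitioner: Option[Partitioner] = None
  }

Calling new RangeRDD(sc, 10, 2).collect() would then return 0 through 9, with each of the two partitions computed by its own task.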

Setting the name, ID, and other such information:

  /** A unique ID for this RDD (within its SparkContext). */
  val id: Int = sc.newRddId()

  /** A friendly name for this RDD */
  @transient var name: String = null

  /** Assign a name to this RDD */
  def setName(_name: String): this.type = {
    name = _name
    this
  }



How to get an RDD's partitions: through the getPartitions function.

  /**
   * Implemented by subclasses to return the set of partitions in this RDD. This method will only
   * be called once, so it is safe to implement a time-consuming computation in it.
   *
   * The partitions in this array must satisfy the following property:
   *   `rdd.partitions.zipWithIndex.forall { case (partition, index) => partition.index == index }`
   */
  protected def getPartitions: Array[Partition]

As you can see, this function must be implemented by subclasses and is called only once, so even a time-consuming computation is safe there. getPartitions returns the set of partitions, i.e. an array of Partition objects.
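
Both points are easy to observe from a Spark shell; a quick sketch, assuming a live SparkContext named sc:

  val rdd = sc.parallelize(1 to 10, numSlices = 3)
  println(rdd.partitions.length)  // 3
  // The property promised in the doc comment above holds:
  assert(rdd.partitions.zipWithIndex.forall { case (p, i) => p.index == i })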

Now let's step into HadoopRDD.

  override def getPartitions: Array[Partition] = {
    val jobConf = getJobConf()
    // add the credentials here as this can be called before SparkContext initialized
    SparkHadoopUtil.get.addCredentials(jobConf)
    val inputFormat = getInputFormat(jobConf)
    val inputSplits = inputFormat.getSplits(jobConf, minPartitions)
    val array = new Array[Partition](inputSplits.size)
    for (i <- 0 until inputSplits.size) {
      array(i) = new HadoopPartition(id, i, inputSplits(i))
    }
    array
  }

The above is HadoopRDD's getPartitions method. It delegates to the InputFormat's own getSplits method to compute the splits, then wraps each split in a HadoopPartition and returns them all in an array.

As a side note, Spark 1.0 also introduced NewHadoopRDD, which uses the InputFormat from the new MapReduce API; its getSplits call no longer takes minPartitions, but the rest of the logic is the same, only the classes used differ slightly.
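
For reference, both code paths are reachable from SparkContext; a sketch, with a hypothetical input path:

  // Old mapred API: builds a HadoopRDD; minPartitions is passed to getSplits.
  val oldApi = sc.textFile("hdfs:///data/input", minPartitions = 8)

  // New mapreduce API: builds a NewHadoopRDD; the split count is decided
  // entirely by the InputFormat, with no minPartitions hint.
  import org.apache.hadoop.io.{LongWritable, Text}
  import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
  val newApi = sc.newAPIHadoopFile[LongWritable, Text, TextInputFormat]("hdfs:///data/input")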

Getting the data of one partition: the compute method.

  override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
    val iter = new NextIterator[(K, V)] {

      val split = theSplit.asInstanceOf[HadoopPartition]
      logInfo("Input split: " + split.inputSplit)
      val jobConf = getJobConf()

      // TODO: there is a lot of duplicate code between this and NewHadoopRDD and SqlNewHadoopRDD

      val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop)
      val existingBytesRead = inputMetrics.bytesRead

      // Sets the thread local variable for the file's name
      split.inputSplit.value match {
        case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString)
        case _ => SqlNewHadoopRDDState.unsetInputFileName()
      }

      // Find a function that will return the FileSystem bytes read by this thread. Do this before
      // creating RecordReader, because RecordReader's constructor might read some bytes
      val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match {
        case _: FileSplit | _: CombineFileSplit =>
          SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
        case _ => None
      }

      // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics.
      // If we do a coalesce, however, we are likely to compute multiple partitions in the same
      // task and in the same thread, in which case we need to avoid override values written by
      // previous partitions (SPARK-13071).
      def updateBytesRead(): Unit = {
        getBytesReadCallback.foreach { getBytesRead =>
          inputMetrics.setBytesRead(existingBytesRead + getBytesRead())
        }
      }

      var reader: RecordReader[K, V] = null
      val inputFormat = getInputFormat(jobConf)
      HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
        context.stageId, theSplit.index, context.attemptNumber, jobConf)
      reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)

      // Register an on-task-completion callback to close the input stream.
      context.addTaskCompletionListener{ context => closeIfNeeded() }
      val key: K = reader.createKey()
      val value: V = reader.createValue()

      override def getNext(): (K, V) = {
        try {
          finished = !reader.next(key, value)
        } catch {
          case eof: EOFException =>
            finished = true
        }
        if (!finished) {
          inputMetrics.incRecordsReadInternal(1)
        }
        if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
          updateBytesRead()
        }
        (key, value)
      }

      override def close() {
        if (reader != null) {
          SqlNewHadoopRDDState.unsetInputFileName()
          // Close the reader and release it. Note: it's very important that we don't close the
          // reader more than once, since that exposes us to MAPREDUCE-5918 when running against
          // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic
          // corruption issues when reading compressed input.
          try {
            reader.close()
          } catch {
            case e: Exception =>
              if (!ShutdownHookManager.inShutdown()) {
                logWarning("Exception in RecordReader.close()", e)
              }
          } finally {
            reader = null
          }
          if (getBytesReadCallback.isDefined) {
            updateBytesRead()
          } else if (split.inputSplit.value.isInstanceOf[FileSplit] ||
                     split.inputSplit.value.isInstanceOf[CombineFileSplit]) {
            // If we can't get the bytes read from the FS stats, fall back to the split size,
            // which may be inaccurate.
            try {
              inputMetrics.incBytesReadInternal(split.inputSplit.value.getLength)
            } catch {
              case e: java.io.IOException =>
                logWarning("Unable to get input size to set InputMetrics for task", e)
            }
          }
        }
      }
    }
    new InterruptibleIterator[(K, V)](context, iter)
  }

Next we look at the compute method itself. Its input is a Partition and its return value is an Iterator[(K, V)] over that partition's data. We only need to pay attention to two things:

1. The Partition is cast to a HadoopPartition, and a RecordReader is created from its InputSplit.

2. The iterator's getNext method is overridden to call the reader's next method to fetch the next value.


From this we can see that the compute method turns a partition into an Iterator over that partition's data.
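
The NextIterator used above is worth a note: it is a small helper that turns the getNext/close pair into a standard Scala Iterator, calling getNext lazily and closing the underlying resource exactly once. A simplified sketch of the pattern (not Spark's exact class):

  abstract class SimpleNextIterator[U] extends Iterator[U] {
    private var gotNext = false
    private var nextValue: U = _
    private var closed = false
    protected var finished = false

    // Produces the next element, or sets `finished = true` when the input is exhausted.
    protected def getNext(): U
    // Releases the underlying resource (e.g. a RecordReader); runs at most once.
    protected def close(): Unit

    def closeIfNeeded(): Unit = if (!closed) { closed = true; close() }

    override def hasNext: Boolean = {
      if (!finished && !gotNext) {
        nextValue = getNext()
        if (finished) closeIfNeeded()
        gotNext = true
      }
      !finished
    }

    override def next(): U = {
      if (!hasNext) throw new NoSuchElementException("End of stream")
      gotNext = false
      nextValue
    }
  }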


Back in RDD, let's now look at some operations on RDDs, starting with the simplest one: map.

  /**
   * Return a new RDD by applying a function to all elements of this RDD.
   */
  def map[U: ClassTag](f: T => U): RDD[U] = withScope {
    val cleanF = sc.clean(f)
    new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.map(cleanF))
  }
The input is a function and the return value is a MapPartitionsRDD, whose constructor arguments are the RDD itself plus the function (context, pid, iter) => iter.map(cleanF).
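
Note that nothing is computed at this point; map only extends the lineage. A quick illustration, assuming a live sc:

  val nums = sc.parallelize(1 to 4, 2)
  val doubled = nums.map(_ * 2)            // returns immediately: just a new MapPartitionsRDD
  println(doubled.toDebugString)           // shows the lineage: MapPartitionsRDD <- ParallelCollectionRDD
  println(doubled.collect().mkString(",")) // 2,4,6,8 -- the computation happens only here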

Now let's look at MapPartitionsRDD:

/**
 * An RDD that applies the provided function to every partition of the parent RDD.
 */
private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
    var prev: RDD[T],
    f: (TaskContext, Int, Iterator[T]) => Iterator[U],  // (TaskContext, partition index, iterator)
    preservesPartitioning: Boolean = false)
  extends RDD[U](prev)

It overrides the partitioner and two important methods:

  override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None

  override def getPartitions: Array[Partition] = firstParent[T].partitions

  override def compute(split: Partition, context: TaskContext): Iterator[U] =
    f(context, split.index, firstParent[T].iterator(split, context))

It uses firstParent[T]. So what exactly is firstParent? Let's first step into the RDD[U](prev) constructor it extends:

  /** Construct an RDD with just a one-to-one dependency on one parent */
  def this(@transient oneParent: RDD[_]) =
    this(oneParent.context, List(new OneToOneDependency(oneParent)))

The truth is right here: the constructor passes along the parent RDD's context and wraps the parent in a OneToOneDependency, so the parent becomes MapPartitionsRDD's one dependency. OneToOneDependency is a narrow dependency: the child RDD depends directly on the parent.
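
This is easy to confirm from the shell; a small sketch, assuming a live sc:

  val parent = sc.parallelize(1 to 4)
  val child = parent.map(_ + 1)
  // map wires up a narrow, one-to-one dependency on its parent:
  println(child.dependencies.head.getClass.getSimpleName)  // OneToOneDependency
  // whereas a shuffle produces a wide dependency:
  val wide = parent.map(x => (x, x)).reduceByKey(_ + _)
  println(wide.dependencies.head.getClass.getSimpleName)   // ShuffleDependency

With that in mind, let's continue with firstParent: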

  /** Returns the first parent RDD */
  protected[spark] def firstParent[U: ClassTag]: RDD[U] = {
    dependencies.head.rdd.asInstanceOf[RDD[U]]
  }

From this we can draw two conclusions:

1. getPartitions simply reuses the parent RDD's partition information.

2. compute wraps the anonymous function f around the parent RDD's iteration over each row of data.

Now we can understand what the compute function is really doing.

It serves two distinct purposes:

1. When there are no dependencies, it generates the Iterator for traversing the data, based on the partition information.

2. When there is a parent dependency, it layers another function on top of the parent RDD's Iterator, applied to each element as it is traversed (see the toy model below).
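
A toy model in plain Scala (no Spark classes involved) makes the second point concrete: each compute merely wraps the parent's iterator in another function, and nothing runs until the final iterator is consumed:

  // Stands in for a source RDD's compute, e.g. HadoopRDD reading a split:
  val sourceCompute: () => Iterator[Int] = () => Iterator(1, 2, 3)
  // Stand-ins for two MapPartitionsRDD computes stacked on top:
  val plusOne: Iterator[Int] => Iterator[Int] = _.map(_ + 1)
  val timesTen: Iterator[Int] => Iterator[Int] = _.map(_ * 10)

  val chained = timesTen(plusOne(sourceCompute()))  // still nothing computed
  println(chained.toList)                           // List(20, 30, 40), produced one element at a time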


Let's keep tracing:

  final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
    if (storageLevel != StorageLevel.NONE) {
      SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
    } else {
      computeOrReadCheckpoint(split, context)
    }
  }

firstParent's iterator method returns an Iterator object: if the RDD is persisted it goes through the cache manager, otherwise it calls computeOrReadCheckpoint:

  /**
   * Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing.
   */
  private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
  {
    if (isCheckpointedAndMaterialized) {
      firstParent[T].iterator(split, context)
    } else {
      compute(split, context)
    }
  }

computeOrReadCheckpoint calls compute, which returns an Iterator[T] over the given partition's data, on which the function f is then applied. A perfect closed loop. This way Spark remembers the computation, not the data.
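
Both branches of iterator can be exercised explicitly; a hedged sketch, with a hypothetical checkpoint directory:

  sc.setCheckpointDir("/tmp/spark-ckpt")  // hypothetical local path
  val data = sc.parallelize(1 to 100).map(_ * 2)
  data.cache()       // storageLevel != NONE, so iterator() goes through the cache manager
  data.checkpoint()  // once materialized, computeOrReadCheckpoint reads the saved copy instead
  data.count()       // first action: computes, populates the cache, writes the checkpoint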


As for map, the Scala Iterator class itself defines that function; likewise, Iterator has methods such as flatMap and filter:

  def map[B](f: A => B): Iterator[B] = new AbstractIterator[B] {
    def hasNext = self.hasNext
    def next() = f(self.next())
  }


The filter and flatMap operations follow exactly the same pattern.
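
All of these Iterator methods are lazy, which is exactly what makes the per-element pipelining above work; a tiny demonstration:

  val it = Iterator(1, 2, 3).map { x => println(s"touching $x"); x * 2 }
  // Nothing has printed yet: map only wrapped the iterator.
  println(it.next())  // prints "touching 1", then 2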

