Spark SQL Code Generation

Spark code generation

Spark SQL pipeline: logical plan -> physical execution plan -> generated code, executed via Spark RDDs.
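You can see all of these stages for a concrete query from the Dataset itself. A minimal sketch, assuming Spark 3.x and a local session (the object and names here are illustrative, not from the original post):

import org.apache.spark.sql.SparkSession

object PlanInspection {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("plans").getOrCreate()
    val df = spark.range(100).filter("id % 2 = 1")

    // Prints the parsed, analyzed and optimized logical plans plus the physical plan.
    df.explain(extended = true)

    // The same plans are available programmatically.
    println(df.queryExecution.optimizedPlan)
    println(df.queryExecution.executedPlan)

    spark.stop()
  }
}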

SparkPlan and the Volcano Model

The Volcano Model

https://zhuanlan.zhihu.com/p/478851521
In the Volcano model, each operator repeatedly calls next() on its child to pull a row, processes it, and returns the result to its parent.
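A minimal sketch of the idea in plain Scala (illustrative only, not Spark code): every operator exposes next(), and a parent pulls rows from its child one at a time, so each row costs at least one virtual call per operator.

// A toy Volcano-style pipeline: Scan -> Filter, driven by repeated next() calls.
trait RowOperator {
  def next(): Option[Int] // None means the operator is exhausted
}

class ScanOp(data: Iterator[Int]) extends RowOperator {
  def next(): Option[Int] = if (data.hasNext) Some(data.next()) else None
}

class FilterOp(child: RowOperator, p: Int => Boolean) extends RowOperator {
  def next(): Option[Int] = {
    var row = child.next()
    while (row.exists(r => !p(r))) row = child.next() // keep pulling until a row passes
    row
  }
}

object VolcanoDemo {
  def main(args: Array[String]): Unit = {
    val plan = new FilterOp(new ScanOp((1 to 10).iterator), _ % 2 == 1)
    var row = plan.next()
    while (row.isDefined) {
      println(row.get)
      row = plan.next()
    }
  }
}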

SparkPlan

SparkPlan is the physical execution plan; it can be executed directly.

org.apache.spark.sql.execution.SparkPlan
  /**
   * Produces the result of the query as an `RDD[InternalRow]`
   *
   * Overridden by concrete implementations of SparkPlan.
   */
  protected def doExecute(): RDD[InternalRow]
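As a quick illustration of "can be executed directly": the physical plan of a Dataset can be taken from its QueryExecution and executed by hand, yielding an RDD[InternalRow]. This is a developer-facing API; a minimal sketch assuming a local session:

import org.apache.spark.sql.SparkSession

object ExecutePlanDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sparkplan").getOrCreate()

    val plan = spark.range(10).filter("id % 2 = 1").queryExecution.executedPlan

    // execute() returns an RDD[InternalRow]; here we just count the rows.
    println(plan.execute().count()) // 5

    spark.stop()
  }
}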
Example: Filter

org.apache.spark.sql.execution.FilterExec

  protected override def doExecute(): RDD[InternalRow] = {
    val numOutputRows = longMetric("numOutputRows")
    child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
      val predicate = Predicate.create(condition, child.output)
      predicate.initialize(0)
      iter.filter { row =>
        val r = predicate.eval(row)
        if (r) numOutputRows += 1
        r
      }
    }
  }

As you can see, the physical plan is ultimately executed through Spark RDDs as well.

Example: Range

org.apache.spark.sql.execution.RangeExec

  protected override def doExecute(): RDD[InternalRow] = {
    val numOutputRows = longMetric("numOutputRows")
    if (isEmptyRange) {
      new EmptyRDD[InternalRow](sqlContext.sparkContext)
    } else {
      sqlContext
        .sparkContext
        .parallelize(0 until numSlices, numSlices)
        .mapPartitionsWithIndex { (i, _) =>
          val partitionStart = (i * numElements) / numSlices * step + start
          val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start

          def getSafeMargin(bi: BigInt): Long =
            if (bi.isValidLong) {
              bi.toLong
            } else if (bi > 0) {
              Long.MaxValue
            } else {
              Long.MinValue
            }

          val safePartitionStart = getSafeMargin(partitionStart)
          val safePartitionEnd = getSafeMargin(partitionEnd)
          val rowSize = UnsafeRow.calculateBitSetWidthInBytes(1) + LongType.defaultSize
          val unsafeRow = UnsafeRow.createFromByteArray(rowSize, 1)
          val taskContext = TaskContext.get()

          val iter = new Iterator[InternalRow] {
            private[this] var number: Long = safePartitionStart
            private[this] var overflow: Boolean = false
            private[this] val inputMetrics = taskContext.taskMetrics().inputMetrics

            override def hasNext =
              if (!overflow) {
                if (step > 0) {
                  number < safePartitionEnd
                } else {
                  number > safePartitionEnd
                }
              } else false

            override def next() = {
              val ret = number
              number += step
              if (number < ret ^ step < 0) {
                // we have Long.MaxValue + Long.MaxValue < Long.MaxValue
                // and Long.MinValue + Long.MinValue > Long.MinValue, so iff the step causes a step
                // back, we are pretty sure that we have an overflow.
                overflow = true
              }

              numOutputRows += 1
              inputMetrics.incRecordsRead(1)
              unsafeRow.setLong(0, ret)
              unsafeRow
            }
          }
          new InterruptibleIterator(taskContext, iter)
        }
    }
  }
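The interesting part of RangeExec is how each task derives its own [start, end) sub-range from its slice index. That arithmetic can be reproduced outside Spark; the helper below is hypothetical but follows the same partitionStart/partitionEnd formula as above:

object RangePartitioning {
  // Mirrors the partitionStart / partitionEnd arithmetic in RangeExec.doExecute.
  def partitionBounds(start: Long, step: Long, numElements: BigInt,
                      numSlices: Int, i: Int): (BigInt, BigInt) = {
    val partitionStart = (BigInt(i) * numElements) / numSlices * step + start
    val partitionEnd = (BigInt(i + 1) * numElements) / numSlices * step + start
    (partitionStart, partitionEnd)
  }

  def main(args: Array[String]): Unit = {
    // range(0, 10, step = 1) split into 3 slices -> (0,3), (3,6), (6,10)
    (0 until 3).foreach { i =>
      println(partitionBounds(start = 0L, step = 1L, numElements = BigInt(10), numSlices = 3, i = i))
    }
  }
}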

Drawbacks of the Volcano Model and Whole-Stage Code Generation

Whole-stage code generation was introduced in Spark 2.0 as part of the Tungsten engine. It was inspired by Thomas Neumann's paper "Efficiently Compiling Efficient Query Plans for Modern Hardware".

Execution should be data-centric rather than operator-centric: the call boundaries between operators should be removed, and the data being processed should stay in CPU registers as much as possible. Instead of pulling data through iterators, data should be pushed up to the parent operator, which gives better code and data locality.

The Volcano model falls short on both counts:

1. It calls next() for every row, which incurs a very large number of virtual function calls.
2. Those virtual function calls prevent intermediate data from being kept in registers.

Example: summing the odd numbers from 1 to 100.

Volcano model (operator pipeline):
Array.filter(x % 2 == 1).sum()

Whole-stage code generation (hand-written equivalent):
int sum=0;
for(int i=0;i<100;i++){
	if(i%2==1){
		sum+=i;
	}
}
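The same contrast in Scala (illustrative only): the first version routes every element through collection/iterator machinery, while the second is a single tight loop whose state (sum, i) can stay in registers.

object SumOdds {
  // Volcano-style: per-element calls through the collection pipeline.
  def pullStyle: Int = (1 to 100).filter(_ % 2 == 1).sum

  // Fused, generated-code style: one loop, no intermediate calls or allocations.
  def fusedStyle: Int = {
    var sum = 0
    var i = 1
    while (i <= 100) {
      if (i % 2 == 1) sum += i
      i += 1
    }
    sum
  }

  def main(args: Array[String]): Unit = {
    println(pullStyle)  // 2500
    println(fusedStyle) // 2500
  }
}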

Implementation in Spark

https://issues.apache.org/jira/browse/SPARK-12795
Open this issue and you can see that the whole-stage code generation work was essentially completed by a single contributor.

Benchmark

Maven dependency
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-sql_2.12</artifactId>
        <version>3.1.2</version>
    </dependency>

Benchmark utility (adapted from Spark's own Benchmark class):
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package sparkbenchmark

import java.io.{OutputStream, PrintStream}

import org.apache.commons.io.output.TeeOutputStream
import org.apache.commons.lang3.SystemUtils

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration._
import scala.util.Try

/**
 * Utility class to benchmark components. An example of how to use this is:
 *  val benchmark = new Benchmark("My Benchmark", valuesPerIteration)
 *   benchmark.addCase("V1")(<function>)
 *   benchmark.addCase("V2")(<function>)
 *   benchmark.run
 * This will output the average time to run each function and the rate of each function.
 *
 * The benchmark function takes one argument that is the iteration that's being run.
 *
 * @param name name of this benchmark.
 * @param valuesPerIteration number of values used in the test case, used to compute rows/s.
 * @param minNumIters the min number of iterations that will be run per case, not counting warm-up.
 * @param warmupTime amount of time to spend running dummy case iterations for JIT warm-up.
 * @param minTime further iterations will be run for each case until this time is used up.
 * @param outputPerIteration if true, the timing for each run will be printed to stdout.
 * @param output optional output stream to write benchmark results to
 */
 class Benchmark(
    name: String,
    valuesPerIteration: Long,
    minNumIters: Int = 2,
    warmupTime: FiniteDuration = 2.seconds,
    minTime: FiniteDuration = 2.seconds,
    outputPerIteration: Boolean = false,
    output: Option[OutputStream] = None) {
  import Benchmark._
  val benchmarks = mutable.ArrayBuffer.empty[Benchmark.Case]

  val out = if (output.isDefined) {
    new PrintStream(new TeeOutputStream(System.out, output.get))
  } else {
    System.out
  }

  /**
   * Adds a case to run when run() is called. The given function will be run for several
   * iterations to collect timing statistics.
   *
   * @param name of the benchmark case
   * @param numIters if non-zero, forces exactly this many iterations to be run
   */
  def addCase(name: String, numIters: Int = 0)(f: Int => Unit): Unit = {
    addTimerCase(name, numIters) { timer =>
      timer.startTiming()
      f(timer.iteration)
      timer.stopTiming()
    }
  }

  /**
   * Adds a case with manual timing control. When the function is run, timing does not start
   * until timer.startTiming() is called within the given function. The corresponding
   * timer.stopTiming() method must be called before the function returns.
   *
   * @param name of the benchmark case
   * @param numIters if non-zero, forces exactly this many iterations to be run
   */
  def addTimerCase(name: String, numIters: Int = 0)(f: Benchmark.Timer => Unit): Unit = {
    benchmarks += Benchmark.Case(name, f, numIters)
  }

  /**
   * Runs the benchmark and outputs the results to stdout. This should be copied and added as
   * a comment with the benchmark. Although the results vary from machine to machine, it should
   * provide some baseline.
   */
  def run(): Unit = {
    require(benchmarks.nonEmpty)
    // scalastyle:off
    println("Running benchmark: " + name)

    val results = benchmarks.map { c =>
      println("  Running case: " + c.name)
      measure(valuesPerIteration, c.numIters)(c.fn)
    }
    println

    val firstBest = results.head.bestMs
    // The results are going to be processor specific so it is useful to include that.
    out.println(Benchmark.getJVMOSInfo())
    out.println(Benchmark.getProcessorName())
    val nameLen = Math.max(40, Math.max(name.length, benchmarks.map(_.name.length).max))
    out.printf(s"%-${nameLen}s %14s %14s %11s %12s %13s %10s\n",
      name + ":", "Best Time(ms)", "Avg Time(ms)", "Stdev(ms)", "Rate(M/s)", "Per Row(ns)", "Relative")
    out.println("-" * (nameLen + 80))
    results.zip(benchmarks).foreach { case (result, benchmark) =>
      out.printf(s"%-${nameLen}s %14s %14s %11s %12s %13s %10s\n",
        benchmark.name,
        "%5.0f" format result.bestMs,
        "%4.0f" format result.avgMs,
        "%5.0f" format result.stdevMs,
        "%10.1f" format result.bestRate,
        "%6.1f" format (1000 / result.bestRate),
        "%3.1fX" format (firstBest / result.bestMs))
    }
    out.println
    // scalastyle:on
  }

  /**
   * Runs a single function `f` for iters, returning the average time the function took and
   * the rate of the function.
   */
  def measure(num: Long, overrideNumIters: Int)(f: Timer => Unit): Result = {
    System.gc()  // ensures garbage from previous cases don't impact this one
    val warmupDeadline = warmupTime.fromNow
    while (!warmupDeadline.isOverdue) {
      f(new Benchmark.Timer(-1))
    }
    val minIters = if (overrideNumIters != 0) overrideNumIters else minNumIters
    val minDuration = if (overrideNumIters != 0) 0 else minTime.toNanos
    val runTimes = ArrayBuffer[Long]()
    var totalTime = 0L
    var i = 0
    while (i < minIters || totalTime < minDuration) {
      val timer = new Benchmark.Timer(i)
      f(timer)
      val runTime = timer.totalTime()
      runTimes += runTime
      totalTime += runTime

      if (outputPerIteration) {
        // scalastyle:off
        println(s"Iteration $i took ${NANOSECONDS.toMicros(runTime)} microseconds")
        // scalastyle:on
      }
      i += 1
    }
    // scalastyle:off
    println(s"  Stopped after $i iterations, ${NANOSECONDS.toMillis(runTimes.sum)} ms")
    // scalastyle:on
    assert(runTimes.nonEmpty)
    val best = runTimes.min
    val avg = runTimes.sum / runTimes.size
    val stdev = if (runTimes.size > 1) {
      math.sqrt(runTimes.map(time => (time - avg) * (time - avg)).sum / (runTimes.size - 1))
    } else 0
    Result(avg / 1000000.0, num / (best / 1000.0), best / 1000000.0, stdev / 1000000.0)
  }
}

object Benchmark {

  /**
   * Object available to benchmark code to control timing e.g. to exclude set-up time.
   *
   * @param iteration specifies this is the nth iteration of running the benchmark case
   */
  class Timer(val iteration: Int) {
    private var accumulatedTime: Long = 0L
    private var timeStart: Long = 0L

    def startTiming(): Unit = {
      assert(timeStart == 0L, "Already started timing.")
      timeStart = System.nanoTime
    }

    def stopTiming(): Unit = {
      assert(timeStart != 0L, "Have not started timing.")
      accumulatedTime += System.nanoTime - timeStart
      timeStart = 0L
    }

    def totalTime(): Long = {
      assert(timeStart == 0L, "Have not stopped timing.")
      accumulatedTime
    }
  }

  case class Case(name: String, fn: Timer => Unit, numIters: Int)
  case class Result(avgMs: Double, bestRate: Double, bestMs: Double, stdevMs: Double)

  /**
   * This should return a user helpful processor information. Getting at this depends on the OS.
   * This should return something like "Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz"
   */
  def getProcessorName(): String = {
//    val cpu = if (SystemUtils.IS_OS_MAC_OSX) {
//      Utils.executeAndGetOutput(Seq("/usr/sbin/sysctl", "-n", "machdep.cpu.brand_string"))
//        .stripLineEnd
//    } else if (SystemUtils.IS_OS_LINUX) {
//      Try {
//        val grepPath = Utils.executeAndGetOutput(Seq("which", "grep")).stripLineEnd
//        Utils.executeAndGetOutput(Seq(grepPath, "-m", "1", "model name", "/proc/cpuinfo"))
//          .stripLineEnd.replaceFirst("model name[\\s*]:[\\s*]", "")
//      }.getOrElse("Unknown processor")
//    } else {
//      System.getenv("PROCESSOR_IDENTIFIER")
//    }
    "cpu"
  }

  /**
   * This should return a user helpful JVM & OS information.
   * This should return something like
   * "OpenJDK 64-Bit Server VM 1.8.0_65-b17 on Linux 4.1.13-100.fc21.x86_64"
   */
  def getJVMOSInfo(): String = {
    val vmName = System.getProperty("java.vm.name")
    val runtimeVersion = System.getProperty("java.runtime.version")
    val osName = System.getProperty("os.name")
    val osVersion = System.getProperty("os.version")
    s"${vmName} ${runtimeVersion} on ${osName} ${osVersion}"
  }
}

Test

import org.apache.spark.sql.SparkSession
import sparkbenchmark.Benchmark

object Testxx {
  def testWholeStage(values: Int): Unit = {
    // In Spark 3.x the SQLContext is obtained through a SparkSession
    // (SQLContext.getOrCreate was removed in Spark 3.0).
    val spark = SparkSession.builder().master("local[1]").appName("benchmark").getOrCreate()
    val sqlContext = spark.sqlContext

    val benchmark = new Benchmark("Single Int Column Scan", values)

    benchmark.addCase("Without whole stage codegen") { iter =>
      sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
      sqlContext.range(values).filter("(id & 1) = 1").count()
    }

    benchmark.addCase("With whole stage codegen") { iter =>
      sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
      sqlContext.range(values).filter("(id & 1) = 1").count()
    }

    /*
      Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
      Single Int Column Scan:      Avg Time(ms)    Avg Rate(M/s)  Relative Rate
      -------------------------------------------------------------------------
      Without whole stage codegen       6725.52            31.18         1.00 X
      With whole stage codegen          2233.05            93.91         3.01 X
    */
    benchmark.run()
  }

  def main(args: Array[String]): Unit = {
    testWholeStage(1024 * 1024 * 200)
  }
}

Configuration flag:
sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
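To see the Java source that whole-stage codegen actually produces for a query, Spark ships a debug helper in org.apache.spark.sql.execution.debug. A minimal sketch, assuming a local session:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.debug._

object CodegenDump {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("codegen").getOrCreate()
    spark.conf.set("spark.sql.codegen.wholeStage", "true")

    // Prints each whole-stage codegen subtree together with its generated Java source.
    spark.range(100).filter("(id & 1) = 1").debugCodegen()

    spark.stop()
  }
}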

Detailed Code Analysis

Check out the branch corresponding to SPARK-14722 to follow the code below.

Spark code generation pull request

https://github.com/apache/spark/pull/10735
https://www.databricks.com/session_na20/understanding-and-improving-code-generation

A SparkPlan that supports codegen needs to implement doProduce() and doConsume():

def doProduce(ctx: CodegenContext): String
def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String
trait CodegenSupport extends SparkPlan {

  /** Prefix used in the current operator's variable names. */
  private def variablePrefix: String = this match {
    case _: TungstenAggregate => "agg"
    case _: BroadcastHashJoin => "bhj"
    case _: SortMergeJoin => "smj"
    case _: PhysicalRDD => "rdd"
    case _: DataSourceScan => "scan"
    case _ => nodeName.toLowerCase
  }

  /**
   * Creates a metric using the specified name.
   *
   * @return name of the variable representing the metric
   */
  def metricTerm(ctx: CodegenContext, name: String): String = {
    val metric = ctx.addReferenceObj(name, longMetric(name))
    val value = ctx.freshName("metricValue")
    val cls = classOf[LongSQLMetricValue].getName
    ctx.addMutableState(cls, value, s"$value = ($cls) $metric.localValue();")
    value
  }

  /**
   * Whether this SparkPlan support whole stage codegen or not.
   */
  def supportCodegen: Boolean = true

  /**
   * Which SparkPlan is calling produce() of this one. It's itself for the first SparkPlan.
   */
  protected var parent: CodegenSupport = null

  /**
   * Returns all the RDDs of InternalRow which generates the input rows.
   *
   * Note: right now we support up to two RDDs.
   */
  def inputRDDs(): Seq[RDD[InternalRow]]

  /**
   * Returns Java source code to process the rows from input RDD.
   */
  final def produce(ctx: CodegenContext, parent: CodegenSupport): String = {
    this.parent = parent
    ctx.freshNamePrefix = variablePrefix
    waitForSubqueries()
    s"""
       |/*** PRODUCE: ${toCommentSafeString(this.simpleString)} */
       |${doProduce(ctx)}
     """.stripMargin
  }

  /**
   * Generate the Java source code to process, should be overridden by subclass to support codegen.
   *
   * doProduce() usually generate the framework, for example, aggregation could generate this:
   *
   *   if (!initialized) {
   *     # create a hash map, then build the aggregation hash map
   *     # call child.produce()
   *     initialized = true;
   *   }
   *   while (hashmap.hasNext()) {
   *     row = hashmap.next();
   *     # build the aggregation results
   *     # create variables for results
   *     # call consume(), which will call parent.doConsume()
   *      if (shouldStop()) return;
   *   }
   */
  protected def doProduce(ctx: CodegenContext): String

  /**
   * Consume the generated columns or row from current SparkPlan, call it's parent's doConsume().
   */
  final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String = {
    val inputVars =
      if (row != null) {
        ctx.currentVars = null
        ctx.INPUT_ROW = row
        output.zipWithIndex.map { case (attr, i) =>
          BoundReference(i, attr.dataType, attr.nullable).genCode(ctx)
        }
      } else {
        assert(outputVars != null)
        assert(outputVars.length == output.length)
        // outputVars will be used to generate the code for UnsafeRow, so we should copy them
        outputVars.map(_.copy())
      }

    val rowVar = if (row != null) {
      ExprCode("", "false", row)
    } else {
      if (outputVars.nonEmpty) {
        val colExprs = output.zipWithIndex.map { case (attr, i) =>
          BoundReference(i, attr.dataType, attr.nullable)
        }
        val evaluateInputs = evaluateVariables(outputVars)
        // generate the code to create a UnsafeRow
        ctx.currentVars = outputVars
        val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
        val code = s"""
          |$evaluateInputs
          |${ev.code.trim}
         """.stripMargin.trim
        ExprCode(code, "false", ev.value)
      } else {
        // There is no columns
        ExprCode("", "false", "unsafeRow")
      }
    }

    ctx.freshNamePrefix = parent.variablePrefix
    val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs)
    s"""
       |
       |/*** CONSUME: ${toCommentSafeString(parent.simpleString)} */
       |$evaluated
       |${parent.doConsume(ctx, inputVars, rowVar)}
     """.stripMargin
  }

  /**
   * Returns source code to evaluate all the variables, and clear the code of them, to prevent
   * them to be evaluated twice.
   */
  protected def evaluateVariables(variables: Seq[ExprCode]): String = {
    val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n")
    variables.foreach(_.code = "")
    evaluate
  }

  /**
   * Returns source code to evaluate the variables for required attributes, and clear the code
   * of evaluated variables, to prevent them to be evaluated twice.
   */
  protected def evaluateRequiredVariables(
      attributes: Seq[Attribute],
      variables: Seq[ExprCode],
      required: AttributeSet): String = {
    val evaluateVars = new StringBuilder
    variables.zipWithIndex.foreach { case (ev, i) =>
      if (ev.code != "" && required.contains(attributes(i))) {
        evaluateVars.append(ev.code.trim + "\n")
        ev.code = ""
      }
    }
    evaluateVars.toString()
  }

  /**
   * The subset of inputSet those should be evaluated before this plan.
   *
   * We will use this to insert some code to access those columns that are actually used by current
   * plan before calling doConsume().
   */
  def usedInputs: AttributeSet = references

  /**
   * Generate the Java source code to process the rows from child SparkPlan.
   *
   * This should be override by subclass to support codegen.
   *
   * For example, Filter will generate the code like this:
   *
   *   # code to evaluate the predicate expression, result is isNull1 and value2
   *   if (isNull1 || !value2) continue;
   *   # call consume(), which will call parent.doConsume()
   *
   * Note: A plan can either consume the rows as UnsafeRow (row), or a list of variables (input).
   */
  def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
    throw new UnsupportedOperationException
  }
}


/**
 * InputAdapter is used to hide a SparkPlan from a subtree that support codegen.
 *
 * This is the leaf node of a tree with WholeStageCodegen, is used to generate code that consumes
 * an RDD iterator of InternalRow.
 *
 * (i.e. the leaf node of a WholeStageCodegen subtree)
 */
case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport {

  override def output: Seq[Attribute] = child.output
  override def outputPartitioning: Partitioning = child.outputPartitioning
  override def outputOrdering: Seq[SortOrder] = child.outputOrdering

  override def doExecute(): RDD[InternalRow] = {
    child.execute()
  }

  override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
    child.doExecuteBroadcast()
  }
  // If a node does not produce an RDD itself, it simply delegates to its child's
  // inputRDDs(), e.g. ProjectExec's:
  //   override def inputRDDs(): Seq[RDD[InternalRow]] = {
  //     child.asInstanceOf[CodegenSupport].inputRDDs()
  //   }
  override def inputRDDs(): Seq[RDD[InternalRow]] = {
    child.execute() :: Nil
  }

  override def doProduce(ctx: CodegenContext): String = {
    val input = ctx.freshName("input")
    // Right now, InputAdapter is only used when there is one input RDD.
    ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
    val row = ctx.freshName("row")
    s"""
       | while ($input.hasNext()) {
       |   InternalRow $row = (InternalRow) $input.next();
       |   ${consume(ctx, null, row).trim}
       |   if (shouldStop()) return;
       | }
     """.stripMargin
  }

  override def simpleString: String = "INPUT"

  override def treeChildren: Seq[SparkPlan] = Nil
}

object WholeStageCodegen {
  val PIPELINE_DURATION_METRIC = "duration"
}

/**
 * WholeStageCodegen compile a subtree of plans that support codegen together into single Java
 * function.
 *
 * Here is the call graph of to generate Java source (plan A support codegen, but plan B does not):
 *
 *   WholeStageCodegen       Plan A               FakeInput        Plan B
 * =========================================================================
 *
 * -> execute()
 *     |
 *  doExecute() --------->   inputRDDs() -------> inputRDDs() ------> execute()
 *     |
 *     +----------------->   produce()
 *                             |
 *                          doProduce()  -------> produce()
 *                                                   |
 *                                                doProduce()
 *                                                   |
 *                         doConsume() <--------- consume()
 *                             |
 *  doConsume()  <--------  consume()
 *
 * SparkPlan A should override doProduce() and doConsume().
 *
 * doCodeGen() will create a CodeGenContext, which will hold a list of variables for input,
 * used to generated code for BoundReference.
 */
case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSupport {

  override def output: Seq[Attribute] = child.output
  override def outputPartitioning: Partitioning = child.outputPartitioning
  override def outputOrdering: Seq[SortOrder] = child.outputOrdering

  override private[sql] lazy val metrics = Map(
    "pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext,
      WholeStageCodegen.PIPELINE_DURATION_METRIC))

  /**
   * Generates code for this subtree.
   *
   * @return the tuple of the codegen context and the actual generated source.
   */
  def doCodeGen(): (CodegenContext, String) = {
    val ctx = new CodegenContext
    val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
    val source = s"""
      public Object generate(Object[] references) {
        return new GeneratedIterator(references);
      }

      /** Codegened pipeline for:
       * ${toCommentSafeString(child.treeString.trim)}
       */
      final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {

        private Object[] references;
        ${ctx.declareMutableStates()}

        public GeneratedIterator(Object[] references) {
          this.references = references;
        }

        public void init(int index, scala.collection.Iterator inputs[]) {
          partitionIndex = index;
          ${ctx.initMutableStates()}
        }

        ${ctx.declareAddedFunctions()}

        protected void processNext() throws java.io.IOException {
          ${code.trim}
        }
      }
      """.trim

    // try to compile, helpful for debug
    val cleanedSource = CodeFormatter.stripExtraNewLines(source)
    logDebug(s"\n${CodeFormatter.format(cleanedSource)}")
    CodeGenerator.compile(cleanedSource)
    (ctx, cleanedSource)
  }

  override def doExecute(): RDD[InternalRow] = {
    // generate the Java source for this subtree
    val (ctx, cleanedSource) = doCodeGen()
    val references = ctx.references.toArray

    val durationMs = longMetric("pipelineTime")
    // get the input RDDs of the subtree
    val rdds = child.asInstanceOf[CodegenSupport].inputRDDs()
    assert(rdds.size <= 2, "Up to two input RDDs can be supported")
    // with a single input RDD, use mapPartitionsWithIndex
    if (rdds.length == 1) {
      rdds.head.mapPartitionsWithIndex { (index, iter) =>
        val clazz = CodeGenerator.compile(cleanedSource)
        val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
        buffer.init(index, Array(iter))
        new Iterator[InternalRow] {
          override def hasNext: Boolean = {
            val v = buffer.hasNext
            if (!v) durationMs += buffer.durationMs()
            v
          }
          override def next: InternalRow = buffer.next()
        }
      }
    } else {
      // Right now, we support up to two input RDDs.
      // the two input RDDs must have the same number of partitions
      rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
        val partitionIndex = TaskContext.getPartitionId()
        val clazz = CodeGenerator.compile(cleanedSource)
        val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
        buffer.init(partitionIndex, Array(leftIter, rightIter))
        new Iterator[InternalRow] {
          override def hasNext: Boolean = {
            val v = buffer.hasNext
            if (!v) durationMs += buffer.durationMs()
            v
          }
          override def next: InternalRow = buffer.next()
        }
      }
    }
  }

  override def inputRDDs(): Seq[RDD[InternalRow]] = {
    throw new UnsupportedOperationException
  }

  override def doProduce(ctx: CodegenContext): String = {
    throw new UnsupportedOperationException
  }

  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
    val doCopy = if (ctx.copyResult) {
      ".copy()"
    } else {
      ""
    }
    s"""
      |${row.code}
      |append(${row.value}$doCopy);
     """.stripMargin.trim
  }

  override def innerChildren: Seq[SparkPlan] = {
    child :: Nil
  }

  private def collectInputs(plan: SparkPlan): Seq[SparkPlan] = plan match {
    case InputAdapter(c) => c :: Nil
    case other => other.children.flatMap(collectInputs)
  }

  override def treeChildren: Seq[SparkPlan] = {
    collectInputs(child)
  }

  override def simpleString: String = "WholeStageCodegen"
}


/**
 * Find the chained plans that support codegen, collapse them together as WholeStageCodegen.
 * i.e. find chained plans that support codegen and collapse them into a single WholeStageCodegen.
 */
case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {

  private def supportCodegen(e: Expression): Boolean = e match {
    case e: LeafExpression => true
    case e: CaseWhen => e.shouldCodegen
    // CodegenFallback requires the input to be an InternalRow
    case e: CodegenFallback => false
    case _ => true
  }

  private def numOfNestedFields(dataType: DataType): Int = dataType match {
    case dt: StructType => dt.fields.map(f => numOfNestedFields(f.dataType)).sum
    case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType)
    case a: ArrayType => numOfNestedFields(a.elementType)
    case u: UserDefinedType[_] => numOfNestedFields(u.sqlType)
    case _ => 1
  }

  private def supportCodegen(plan: SparkPlan): Boolean = plan match {
    case plan: CodegenSupport if plan.supportCodegen =>
      val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined)
      // the generated code will be huge if there are too many columns
      val hasTooManyOutputFields =
        numOfNestedFields(plan.schema) > conf.wholeStageMaxNumFields
      val hasTooManyInputFields =
        plan.children.map(p => numOfNestedFields(p.schema)).exists(_ > conf.wholeStageMaxNumFields)
      !willFallback && !hasTooManyOutputFields && !hasTooManyInputFields
    case _ => false
  }

  /**
   * Inserts a InputAdapter on top of those that do not support codegen.
   */
  private def insertInputAdapter(plan: SparkPlan): SparkPlan = plan match {
    case j @ SortMergeJoin(_, _, _, _, left, right) if j.supportCodegen =>
      // The children of SortMergeJoin should do codegen separately.
      // wrap both sides in an InputAdapter
      j.copy(left = InputAdapter(insertWholeStageCodegen(left)),
        right = InputAdapter(insertWholeStageCodegen(right)))
    case p if !supportCodegen(p) =>
      // collapse them recursively
      // if the node does not support codegen, wrap it: InputAdapter(insertWholeStageCodegen(p))
      InputAdapter(insertWholeStageCodegen(p))
    case p =>
      // supports codegen: recurse into its children (a leaf simply returns itself)
      p.withNewChildren(p.children.map(insertInputAdapter))
  }

  /**
   * Inserts a WholeStageCodegen on top of those that support codegen.
   */
  private def insertWholeStageCodegen(plan: SparkPlan): SparkPlan = plan match {
    // if this plan supports codegen, wrap the subtree in a WholeStageCodegen
    case plan: CodegenSupport if supportCodegen(plan) =>
      WholeStageCodegen(insertInputAdapter(plan))
    case other =>
      // otherwise recurse into the children
      other.withNewChildren(other.children.map(insertWholeStageCodegen))
  }

  def apply(plan: SparkPlan): SparkPlan = {
    // only if whole-stage codegen is enabled
    if (conf.wholeStageEnabled) {
      insertWholeStageCodegen(plan)
    } else {
      plan
    }
  }
}
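Putting the pieces together: for a Range -> Filter -> count pipeline, produce()/consume() stitch the operators into one loop inside processNext(). The Scala sketch below only shows the shape of the fused code; the real generator emits Java from the string templates above:

object FusedPipelineSketch {
  // Conceptual shape of the fused loop for range(0, end).filter((id & 1) = 1).count():
  // RangeExec produces the loop, FilterExec consumes each value with an early skip,
  // and the aggregate keeps a single local accumulator instead of materializing rows.
  def fusedCount(start: Long, end: Long): Long = {
    var count = 0L        // aggregation buffer
    var id = start        // loop variable produced by RangeExec.doProduce
    while (id < end) {
      if ((id & 1L) == 1L) {   // predicate emitted by FilterExec.doConsume
        count += 1L            // emitted by the aggregate's doConsume
      }
      id += 1L
    }
    count
  }

  def main(args: Array[String]): Unit = {
    println(fusedCount(0L, 100L)) // 50
  }
}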
