Spark code generation
Spark SQL: logical plan -> physical execution plan -> generated Spark RDD code -> execution
SparkPlan and the Volcano model
The Volcano model
https://zhuanlan.zhihu.com/p/478851521
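Each operator exposes a next() method. A parent operator repeatedly calls next() on its child to fetch one row, processes that row, and returns the result to its own caller.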
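The pull-based contract can be illustrated with a small sketch. It is written in Java, the language Spark's code generator emits. VolcanoOperator, VolcanoScan, VolcanoFilter and VolcanoDemo are hypothetical names invented for this illustration; they are not Spark classes.

```java
import java.util.function.LongPredicate;

/** Hypothetical Volcano-style operator: rows are pulled one at a time via next(). */
interface VolcanoOperator {
    /** Returns the next value, or null once the input is exhausted. */
    Long next();
}

/** Leaf operator that produces start, start + 1, ..., end - 1. */
final class VolcanoScan implements VolcanoOperator {
    private long current;
    private final long end;

    VolcanoScan(long start, long end) {
        this.current = start;
        this.end = end;
    }

    @Override
    public Long next() {
        if (current >= end) {
            return null;
        }
        return current++;
    }
}

/** Filter operator: keeps pulling from its child until a row passes the predicate. */
final class VolcanoFilter implements VolcanoOperator {
    private final VolcanoOperator child;
    private final LongPredicate predicate;

    VolcanoFilter(VolcanoOperator child, LongPredicate predicate) {
        this.child = child;
        this.predicate = predicate;
    }

    @Override
    public Long next() {
        Long row;
        while ((row = child.next()) != null) {   // one interface call per input row
            if (predicate.test(row)) {
                return row;
            }
        }
        return null;
    }
}

final class VolcanoDemo {
    /** Root of the plan: drains the operator tree by calling next() until it returns null. */
    static long sum(VolcanoOperator root) {
        long total = 0L;
        Long row;
        while ((row = root.next()) != null) {
            total += row;
        }
        return total;
    }

    public static void main(String[] args) {
        VolcanoOperator plan = new VolcanoFilter(new VolcanoScan(1, 101), x -> x % 2 == 1);
        System.out.println(sum(plan)); // prints 2500
    }
}
```

Every row crosses two interface calls (Filter.next and Scan.next) before it reaches the sum, which is the overhead discussed later in these notes.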
SparkPlan
A SparkPlan is a physical execution plan and can be executed directly.
org.apache.spark.sql.execution.SparkPlan
  /**
   * Produces the result of the query as an `RDD[InternalRow]`
   *
   * Overridden by concrete implementations of SparkPlan.
   */
  protected def doExecute(): RDD[InternalRow]
Example: Filter
org.apache.spark.sql.execution.FilterExec
  protected override def doExecute(): RDD[InternalRow] = {
    val numOutputRows = longMetric("numOutputRows")
    child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
      val predicate = Predicate.create(condition, child.output)
      predicate.initialize(0)
      iter.filter { row =>
        val r = predicate.eval(row)
        if (r) numOutputRows += 1
        r
      }
    }
  }
As the code shows, the operator is ultimately executed through Spark RDD operations.
Example: Range
org.apache.spark.sql.execution.RangeExec
  protected override def doExecute(): RDD[InternalRow] = {
    val numOutputRows = longMetric("numOutputRows")
    if (isEmptyRange) {
      new EmptyRDD[InternalRow](sqlContext.sparkContext)
    } else {
      sqlContext
        .sparkContext
        .parallelize(0 until numSlices, numSlices)
        .mapPartitionsWithIndex { (i, _) =>
          val partitionStart = (i * numElements) / numSlices * step + start
          val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start

          def getSafeMargin(bi: BigInt): Long =
            if (bi.isValidLong) {
              bi.toLong
            } else if (bi > 0) {
              Long.MaxValue
            } else {
              Long.MinValue
            }

          val safePartitionStart = getSafeMargin(partitionStart)
          val safePartitionEnd = getSafeMargin(partitionEnd)
          val rowSize = UnsafeRow.calculateBitSetWidthInBytes(1) + LongType.defaultSize
          val unsafeRow = UnsafeRow.createFromByteArray(rowSize, 1)
          val taskContext = TaskContext.get()

          val iter = new Iterator[InternalRow] {
            private[this] var number: Long = safePartitionStart
            private[this] var overflow: Boolean = false
            private[this] val inputMetrics = taskContext.taskMetrics().inputMetrics

            override def hasNext =
              if (!overflow) {
                if (step > 0) {
                  number < safePartitionEnd
                } else {
                  number > safePartitionEnd
                }
              } else false

            override def next() = {
              val ret = number
              number += step
              if (number < ret ^ step < 0) {
                // we have Long.MaxValue + Long.MaxValue < Long.MaxValue
                // and Long.MinValue + Long.MinValue > Long.MinValue, so iff the step causes a step
                // back, we are pretty sure that we have an overflow.
                overflow = true
              }

              numOutputRows += 1
              inputMetrics.incRecordsRead(1)
              unsafeRow.setLong(0, ret)
              unsafeRow
            }
          }
          new InterruptibleIterator(taskContext, iter)
        }
    }
  }
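Two details of the Range iterator are easy to miss: partition bounds are computed as BigInt and then clamped into the Long range, and overflow of `number += step` is detected with the XOR test `number < ret ^ step < 0`. The following is a minimal Java sketch of just those two checks. RangeOverflowDemo, clampToLong and stepOverflows are hypothetical names for this illustration, and the sketch assumes a non-zero step, as Range does.

```java
import java.math.BigInteger;

final class RangeOverflowDemo {
    /** Clamps an arbitrary-precision bound into the long range, like getSafeMargin above. */
    static long clampToLong(BigInteger bound) {
        if (bound.bitLength() <= 63) {   // fits into a signed 64-bit long
            return bound.longValue();
        }
        return bound.signum() > 0 ? Long.MAX_VALUE : Long.MIN_VALUE;
    }

    /** Same test as `number < ret ^ step < 0`: true if before + step wrapped around. */
    static boolean stepOverflows(long before, long step) {
        long after = before + step;      // Java long addition wraps silently
        // Moving backwards with a positive step, or forwards with a negative one, means overflow.
        return (after < before) ^ (step < 0);
    }

    public static void main(String[] args) {
        BigInteger tooBig = BigInteger.valueOf(Long.MAX_VALUE).add(BigInteger.TEN);
        System.out.println(clampToLong(tooBig) == Long.MAX_VALUE);   // true
        System.out.println(stepOverflows(Long.MAX_VALUE - 1, 5));    // true
        System.out.println(stepOverflows(10, 5));                    // false
        System.out.println(stepOverflows(10, -5));                   // false
    }
}
```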
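Drawbacks of the Volcano model and whole-stage code generation
Whole-stage code generation was introduced in Spark 2.0 as part of the Tungsten engine. It was inspired by Thomas Neumann's paper "Efficiently Compiling Efficient Query Plans for Modern Hardware".
Execution should be data-centric rather than operator-centric: the call boundaries between operators should be broken down so that the data being processed stays in CPU registers for as long as possible.
Data should not be pulled through iterators. It should be pushed to the parent operator, which gives better code and data locality.
1. The Volcano model calls next() for every row, which results in a very large number of virtual function calls.
2. Those virtual function calls prevent intermediate data from being kept in registers.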
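The push direction described above can be sketched in a few lines of Java. PushPipelineDemo, scan, filter and sumOdd are hypothetical names for this illustration. The sketch still dispatches through lambdas at run time, so it only shows the direction of data flow; whole-stage code generation goes one step further and emits the fused loop as source code.

```java
import java.util.function.LongConsumer;
import java.util.function.LongPredicate;

final class PushPipelineDemo {
    /** Source operator: owns the loop and pushes every value to its parent. */
    static void scan(long startInclusive, long endExclusive, LongConsumer parent) {
        for (long i = startInclusive; i < endExclusive; i++) {
            parent.accept(i);
        }
    }

    /** Filter operator: forwards only the matching values to its own parent. */
    static LongConsumer filter(LongPredicate predicate, LongConsumer parent) {
        return value -> {
            if (predicate.test(value)) {
                parent.accept(value);
            }
        };
    }

    /** Wires scan -> filter -> sum and returns the sum of the odd values in the range. */
    static long sumOdd(long startInclusive, long endExclusive) {
        long[] total = {0L};   // single-element array so the lambda can update it
        scan(startInclusive, endExclusive, filter(v -> v % 2 == 1, v -> total[0] += v));
        return total[0];
    }

    public static void main(String[] args) {
        System.out.println(sumOdd(1, 101)); // prints 2500
    }
}
```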
Example: sum the odd numbers from 1 to 100
Volcano model:
(1 to 100).filter(x => x % 2 == 1).sum
Generated code:
int sum = 0;
for (int i = 1; i <= 100; i++) {
  if (i % 2 == 1) {
    sum += i;
  }
}
Implementation
https://issues.apache.org/jira/browse/SPARK-12795
Opening this issue shows that the whole-stage code generation work was carried out essentially by a single contributor.
Benchmark
Maven dependency
<dependency>
  <groupId>org.apache.spark</groupId>
  <artifactId>spark-sql_2.12</artifactId>
  <version>3.1.2</version>
</dependency>
Benchmark utility class (getProcessorName() is stubbed out to return a constant)
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sparkbenchmark
import java.io.{OutputStream, PrintStream}
import org.apache.commons.io.output.TeeOutputStream
import org.apache.commons.lang3.SystemUtils
import org.apache.spark.util.Utils
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration._
import scala.util.Try
/**
* Utility class to benchmark components. An example of how to use this is:
* val benchmark = new Benchmark("My Benchmark", valuesPerIteration)
* benchmark.addCase("V1")(<function>)
* benchmark.addCase("V2")(<function>)
* benchmark.run
* This will output the average time to run each function and the rate of each function.
*
* The benchmark function takes one argument that is the iteration that's being run.
*
* @param name name of this benchmark.
* @param valuesPerIteration number of values used in the test case, used to compute rows/s.
* @param minNumIters the min number of iterations that will be run per case, not counting warm-up.
* @param warmupTime amount of time to spend running dummy case iterations for JIT warm-up.
* @param minTime further iterations will be run for each case until this time is used up.
* @param outputPerIteration if true, the timing for each run will be printed to stdout.
* @param output optional output stream to write benchmark results to
*/
class Benchmark(
name: String,
valuesPerIteration: Long,
minNumIters: Int = 2,
warmupTime: FiniteDuration = 2.seconds,
minTime: FiniteDuration = 2.seconds,
outputPerIteration: Boolean = false,
output: Option[OutputStream] = None) {
import Benchmark._
val benchmarks = mutable.ArrayBuffer.empty[Benchmark.Case]
val out = if (output.isDefined) {
new PrintStream(new TeeOutputStream(System.out, output.get))
} else {
System.out
}
/**
* Adds a case to run when run() is called. The given function will be run for several
* iterations to collect timing statistics.
*
* @param name of the benchmark case
* @param numIters if non-zero, forces exactly this many iterations to be run
*/
def addCase(name: String, numIters: Int = 0)(f: Int => Unit): Unit = {
addTimerCase(name, numIters) { timer =>
timer.startTiming()
f(timer.iteration)
timer.stopTiming()
}
}
/**
* Adds a case with manual timing control. When the function is run, timing does not start
* until timer.startTiming() is called within the given function. The corresponding
* timer.stopTiming() method must be called before the function returns.
*
* @param name of the benchmark case
* @param numIters if non-zero, forces exactly this many iterations to be run
*/
def addTimerCase(name: String, numIters: Int = 0)(f: Benchmark.Timer => Unit): Unit = {
benchmarks += Benchmark.Case(name, f, numIters)
}
/**
* Runs the benchmark and outputs the results to stdout. This should be copied and added as
* a comment with the benchmark. Although the results vary from machine to machine, it should
* provide some baseline.
*/
def run(): Unit = {
require(benchmarks.nonEmpty)
// scalastyle:off
println("Running benchmark: " + name)
val results = benchmarks.map { c =>
println(" Running case: " + c.name)
measure(valuesPerIteration, c.numIters)(c.fn)
}
println
val firstBest = results.head.bestMs
// The results are going to be processor specific so it is useful to include that.
out.println(Benchmark.getJVMOSInfo())
out.println(Benchmark.getProcessorName())
val nameLen = Math.max(40, Math.max(name.length, benchmarks.map(_.name.length).max))
out.printf(s"%-${nameLen}s %14s %14s %11s %12s %13s %10s\n",
name + ":", "Best Time(ms)", "Avg Time(ms)", "Stdev(ms)", "Rate(M/s)", "Per Row(ns)", "Relative")
out.println("-" * (nameLen + 80))
results.zip(benchmarks).foreach { case (result, benchmark) =>
out.printf(s"%-${nameLen}s %14s %14s %11s %12s %13s %10s\n",
benchmark.name,
"%5.0f" format result.bestMs,
"%4.0f" format result.avgMs,
"%5.0f" format result.stdevMs,
"%10.1f" format result.bestRate,
"%6.1f" format (1000 / result.bestRate),
"%3.1fX" format (firstBest / result.bestMs))
}
out.println
// scalastyle:on
}
/**
* Runs a single function `f` for iters, returning the average time the function took and
* the rate of the function.
*/
def measure(num: Long, overrideNumIters: Int)(f: Timer => Unit): Result = {
System.gc() // ensures garbage from previous cases don't impact this one
val warmupDeadline = warmupTime.fromNow
while (!warmupDeadline.isOverdue) {
f(new Benchmark.Timer(-1))
}
val minIters = if (overrideNumIters != 0) overrideNumIters else minNumIters
val minDuration = if (overrideNumIters != 0) 0 else minTime.toNanos
val runTimes = ArrayBuffer[Long]()
var totalTime = 0L
var i = 0
while (i < minIters || totalTime < minDuration) {
val timer = new Benchmark.Timer(i)
f(timer)
val runTime = timer.totalTime()
runTimes += runTime
totalTime += runTime
if (outputPerIteration) {
// scalastyle:off
println(s"Iteration $i took ${NANOSECONDS.toMicros(runTime)} microseconds")
// scalastyle:on
}
i += 1
}
// scalastyle:off
println(s" Stopped after $i iterations, ${NANOSECONDS.toMillis(runTimes.sum)} ms")
// scalastyle:on
assert(runTimes.nonEmpty)
val best = runTimes.min
val avg = runTimes.sum / runTimes.size
val stdev = if (runTimes.size > 1) {
math.sqrt(runTimes.map(time => (time - avg) * (time - avg)).sum / (runTimes.size - 1))
} else 0
Result(avg / 1000000.0, num / (best / 1000.0), best / 1000000.0, stdev / 1000000.0)
}
}
object Benchmark {
/**
* Object available to benchmark code to control timing e.g. to exclude set-up time.
*
* @param iteration specifies this is the nth iteration of running the benchmark case
*/
class Timer(val iteration: Int) {
private var accumulatedTime: Long = 0L
private var timeStart: Long = 0L
def startTiming(): Unit = {
assert(timeStart == 0L, "Already started timing.")
timeStart = System.nanoTime
}
def stopTiming(): Unit = {
assert(timeStart != 0L, "Have not started timing.")
accumulatedTime += System.nanoTime - timeStart
timeStart = 0L
}
def totalTime(): Long = {
assert(timeStart == 0L, "Have not stopped timing.")
accumulatedTime
}
}
case class Case(name: String, fn: Timer => Unit, numIters: Int)
case class Result(avgMs: Double, bestRate: Double, bestMs: Double, stdevMs: Double)
/**
* This should return a user helpful processor information. Getting at this depends on the OS.
* This should return something like "Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz"
*/
def getProcessorName(): String = {
// val cpu = if (SystemUtils.IS_OS_MAC_OSX) {
// Utils.executeAndGetOutput(Seq("/usr/sbin/sysctl", "-n", "machdep.cpu.brand_string"))
// .stripLineEnd
// } else if (SystemUtils.IS_OS_LINUX) {
// Try {
// val grepPath = Utils.executeAndGetOutput(Seq("which", "grep")).stripLineEnd
// Utils.executeAndGetOutput(Seq(grepPath, "-m", "1", "model name", "/proc/cpuinfo"))
// .stripLineEnd.replaceFirst("model name[\\s*]:[\\s*]", "")
// }.getOrElse("Unknown processor")
// } else {
// System.getenv("PROCESSOR_IDENTIFIER")
// }
"cpu"
}
/**
* This should return a user helpful JVM & OS information.
* This should return something like
* "OpenJDK 64-Bit Server VM 1.8.0_65-b17 on Linux 4.1.13-100.fc21.x86_64"
*/
def getJVMOSInfo(): String = {
val vmName = System.getProperty("java.vm.name")
val runtimeVersion = System.getProperty("java.runtime.version")
val osName = System.getProperty("os.name")
val osVersion = System.getProperty("os.version")
s"${vmName} ${runtimeVersion} on ${osName} ${osVersion}"
}
}
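The statistics that measure() reports are easy to misread because the samples are in nanoseconds while the table is in milliseconds and million rows per second. The following Java sketch repeats only that arithmetic. BenchmarkStatsDemo and summarize are hypothetical names, and unlike the Scala code the squared deviations are accumulated as doubles.

```java
final class BenchmarkStatsDemo {
    /** Returns {bestMs, avgMs, stdevMs, bestRate in million values per second}. */
    static double[] summarize(long valuesPerIteration, long[] runTimesNs) {
        if (runTimesNs.length == 0) {
            throw new IllegalArgumentException("at least one sample is required");
        }
        long best = Long.MAX_VALUE;
        long sum = 0L;
        for (long t : runTimesNs) {
            best = Math.min(best, t);
            sum += t;
        }
        long avg = sum / runTimesNs.length;          // integer division, as in measure()

        double squaredDeviations = 0.0;
        for (long t : runTimesNs) {
            double d = t - avg;
            squaredDeviations += d * d;
        }
        // Sample standard deviation (n - 1 in the denominator); 0 for a single sample.
        double stdev = runTimesNs.length > 1
            ? Math.sqrt(squaredDeviations / (runTimesNs.length - 1))
            : 0.0;

        // best / 1000.0 is the best time in microseconds, so values per microsecond
        // equals million values per second.
        double bestRate = valuesPerIteration / (best / 1000.0);
        return new double[] {best / 1e6, avg / 1e6, stdev / 1e6, bestRate};
    }

    public static void main(String[] args) {
        double[] r = summarize(1_000_000L, new long[] {2_000_000L, 4_000_000L});
        System.out.printf("best=%.1f ms avg=%.1f ms stdev=%.3f ms rate=%.1f M/s%n",
            r[0], r[1], r[2], r[3]);
    }
}
```

With two samples of 2 ms and 4 ms over one million values, the best time is 2 ms, the average 3 ms, and the rate 500 M/s. The "Per Row(ns)" column in run() is then 1000 / rate, here 2 ns.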
Test
package sparkbenchmark // assumption: same package as the Benchmark class above

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object Testxx {
  def testWholeStage(values: Int): Unit = {
    val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark")
    val sc = SparkContext.getOrCreate(conf)
    // SQLContext.getOrCreate is deprecated since Spark 2.0 but still available in 3.1.2.
    val sqlContext = SQLContext.getOrCreate(sc)
    val benchmark = new Benchmark("Single Int Column Scan", values)
    benchmark.addCase("Without whole stage codegen") { iter =>
      sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
      sqlContext.range(values).filter("(id & 1) = 1").count()
    }
    benchmark.addCase("With whole stage codegen") { iter =>
      sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
      sqlContext.range(values).filter("(id & 1) = 1").count()
    }
    /*
      Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
      Single Int Column Scan:          Avg Time(ms)    Avg Rate(M/s)  Relative Rate
      -------------------------------------------------------------------------
      Without whole stage codegen           6725.52            31.18         1.00 X
      With whole stage codegen              2233.05            93.91         3.01 X
    */
    benchmark.run()
  }

  def main(args: Array[String]): Unit = {
    testWholeStage(1024 * 1024 * 200)
  }
}
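Configuration parameter
Whole-stage code generation is switched on and off with spark.sql.codegen.wholeStage: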
sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
Code walkthrough
Check out the revision that corresponds to SPARK-14722.
Pull request that introduced whole-stage code generation:
https://github.com/apache/spark/pull/10735
https://www.databricks.com/session_na20/understanding-and-improving-code-generation
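A SparkPlan that supports codegen needs to implement doProduce() and doConsume(). The signatures quoted here differ from the revision examined in the rest of these notes, where doProduce(ctx) returns only a String and doConsume takes (ctx, input, row):
def doProduce(ctx: CodegenContext): (RDD[InternalRow], String)
def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String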
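To make the produce/consume handshake concrete, here is a deliberately tiny Java sketch of the protocol. It is not Spark's API: ToyOperator, ToyRange, ToyFilter and ToySumSink are hypothetical classes, there is no CodegenContext, and a "row" is just the name of a Java variable. produce() travels down to the leaf, which emits the loop, and consume() travels back up so that each parent adds its code inside that loop.

```java
import java.util.function.UnaryOperator;

/** Hypothetical miniature of the produce/consume protocol; none of this is Spark API. */
abstract class ToyOperator {
    protected ToyOperator parent;

    /** Called by the parent: remember who asked, then emit the code that produces rows. */
    final String produce(ToyOperator parent) {
        this.parent = parent;
        return doProduce();
    }

    /** Called once a row is available in variable rowVar: the parent emits its code for it. */
    final String consume(String rowVar) {
        return parent.doConsume(rowVar);
    }

    abstract String doProduce();

    String doConsume(String rowVar) {
        throw new UnsupportedOperationException();
    }

    /** Indents every line of the given code by two spaces. */
    static String indent(String code) {
        return code.replaceAll("(?m)^", "  ");
    }
}

/** Leaf: owns the loop that drives the whole pipeline. */
final class ToyRange extends ToyOperator {
    private final long start;
    private final long end;

    ToyRange(long start, long end) {
        this.start = start;
        this.end = end;
    }

    @Override
    String doProduce() {
        return "for (long i = " + start + "L; i < " + end + "L; i++) {\n"
            + indent(consume("i"))
            + "}\n";
    }
}

/** Filter: produces nothing itself, it contributes one `if` inside the child's loop. */
final class ToyFilter extends ToyOperator {
    private final ToyOperator child;
    private final UnaryOperator<String> condition; // variable name -> Java boolean expression

    ToyFilter(ToyOperator child, UnaryOperator<String> condition) {
        this.child = child;
        this.condition = condition;
    }

    @Override
    String doProduce() {
        return child.produce(this);
    }

    @Override
    String doConsume(String rowVar) {
        return "if (!(" + condition.apply(rowVar) + ")) continue;\n" + consume(rowVar);
    }
}

/** Root sink, standing in for the role WholeStageCodegen plays: it only consumes. */
final class ToySumSink extends ToyOperator {
    private final ToyOperator child;

    ToySumSink(ToyOperator child) {
        this.child = child;
    }

    String generate() {
        return "long sum = 0L;\n" + child.produce(this) + "return sum;\n";
    }

    @Override
    String doProduce() {
        throw new UnsupportedOperationException();
    }

    @Override
    String doConsume(String rowVar) {
        return "sum += " + rowVar + ";\n";
    }
}
```

Calling `new ToySumSink(new ToyFilter(new ToyRange(1, 101), v -> v + " % 2 == 1")).generate()` returns a single fused loop with the filter inlined as `if (!(i % 2 == 1)) continue;` followed by `sum += i;`, which has the same shape as the hand-written loop in the odd-number example.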
trait CodegenSupport extends SparkPlan {
/** Prefix used in the current operator's variable names. */
private def variablePrefix: String = this match {
case _: TungstenAggregate => "agg"
case _: BroadcastHashJoin => "bhj"
case _: SortMergeJoin => "smj"
case _: PhysicalRDD => "rdd"
case _: DataSourceScan => "scan"
case _ => nodeName.toLowerCase
}
/**
* Creates a metric using the specified name.
*
* @return name of the variable representing the metric
*/
def metricTerm(ctx: CodegenContext, name: String): String = {
val metric = ctx.addReferenceObj(name, longMetric(name))
val value = ctx.freshName("metricValue")
val cls = classOf[LongSQLMetricValue].getName
ctx.addMutableState(cls, value, s"$value = ($cls) $metric.localValue();")
value
}
/**
* Whether this SparkPlan support whole stage codegen or not.
*/
def supportCodegen: Boolean = true
/**
* Which SparkPlan is calling produce() of this one. It's itself for the first SparkPlan.
*/
protected var parent: CodegenSupport = null
/**
* Returns all the RDDs of InternalRow which generates the input rows.
*
* Note: right now we support up to two RDDs.
*/
def inputRDDs(): Seq[RDD[InternalRow]]
/**
* Returns Java source code to process the rows from input RDD.
*/
final def produce(ctx: CodegenContext, parent: CodegenSupport): String = {
this.parent = parent
ctx.freshNamePrefix = variablePrefix
waitForSubqueries()
s"""
|/*** PRODUCE: ${toCommentSafeString(this.simpleString)} */
|${doProduce(ctx)}
""".stripMargin
}
/**
* Generate the Java source code to process, should be overridden by subclass to support codegen.
*
* doProduce() usually generate the framework, for example, aggregation could generate this:
*
* if (!initialized) {
* # create a hash map, then build the aggregation hash map
* # call child.produce()
* initialized = true;
* }
* while (hashmap.hasNext()) {
* row = hashmap.next();
* # build the aggregation results
* # create variables for results
* # call consume(), which will call parent.doConsume()
* if (shouldStop()) return;
* }
*/
protected def doProduce(ctx: CodegenContext): String
/**
* Consume the generated columns or row from current SparkPlan, call it's parent's doConsume().
*/
final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String = {
val inputVars =
if (row != null) {
ctx.currentVars = null
ctx.INPUT_ROW = row
output.zipWithIndex.map { case (attr, i) =>
BoundReference(i, attr.dataType, attr.nullable).genCode(ctx)
}
} else {
assert(outputVars != null)
assert(outputVars.length == output.length)
// outputVars will be used to generate the code for UnsafeRow, so we should copy them
outputVars.map(_.copy())
}
val rowVar = if (row != null) {
ExprCode("", "false", row)
} else {
if (outputVars.nonEmpty) {
val colExprs = output.zipWithIndex.map { case (attr, i) =>
BoundReference(i, attr.dataType, attr.nullable)
}
val evaluateInputs = evaluateVariables(outputVars)
// generate the code to create a UnsafeRow
ctx.currentVars = outputVars
val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
val code = s"""
|$evaluateInputs
|${ev.code.trim}
""".stripMargin.trim
ExprCode(code, "false", ev.value)
} else {
// There is no columns
ExprCode("", "false", "unsafeRow")
}
}
ctx.freshNamePrefix = parent.variablePrefix
val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs)
s"""
|
|/*** CONSUME: ${toCommentSafeString(parent.simpleString)} */
|$evaluated
|${parent.doConsume(ctx, inputVars, rowVar)}
""".stripMargin
}
/**
* Returns source code to evaluate all the variables, and clear the code of them, to prevent
* them to be evaluated twice.
*/
protected def evaluateVariables(variables: Seq[ExprCode]): String = {
val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n")
variables.foreach(_.code = "")
evaluate
}
/**
* Returns source code to evaluate the variables for required attributes, and clear the code
* of evaluated variables, to prevent them to be evaluated twice.
*/
protected def evaluateRequiredVariables(
attributes: Seq[Attribute],
variables: Seq[ExprCode],
required: AttributeSet): String = {
val evaluateVars = new StringBuilder
variables.zipWithIndex.foreach { case (ev, i) =>
if (ev.code != "" && required.contains(attributes(i))) {
evaluateVars.append(ev.code.trim + "\n")
ev.code = ""
}
}
evaluateVars.toString()
}
/**
* The subset of inputSet those should be evaluated before this plan.
*
* We will use this to insert some code to access those columns that are actually used by current
* plan before calling doConsume().
*/
def usedInputs: AttributeSet = references
/**
* Generate the Java source code to process the rows from child SparkPlan.
*
* This should be override by subclass to support codegen.
*
* For example, Filter will generate the code like this:
*
* # code to evaluate the predicate expression, result is isNull1 and value2
* if (isNull1 || !value2) continue;
* # call consume(), which will call parent.doConsume()
*
* Note: A plan can either consume the rows as UnsafeRow (row), or a list of variables (input).
*/
def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
throw new UnsupportedOperationException
}
}
/**
* InputAdapter is used to hide a SparkPlan from a subtree that support codegen.
*
* This is the leaf node of a tree with WholeStageCodegen, is used to generate code that consumes
* an RDD iterator of InternalRow.
*
* (That is, InputAdapter is the leaf node of a WholeStageCodegen subtree.)
*/
case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport {
override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = child.outputPartitioning
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
override def doExecute(): RDD[InternalRow] = {
child.execute()
}
override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
child.doExecuteBroadcast()
}
// An operator that cannot create an RDD itself simply delegates to its child,
// e.g. Project does:
// override def inputRDDs(): Seq[RDD[InternalRow]] = {
// child.asInstanceOf[CodegenSupport].inputRDDs()
// }
override def inputRDDs(): Seq[RDD[InternalRow]] = {
child.execute() :: Nil
}
override def doProduce(ctx: CodegenContext): String = {
val input = ctx.freshName("input")
// Right now, InputAdapter is only used when there is one input RDD.
ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
val row = ctx.freshName("row")
s"""
| while ($input.hasNext()) {
| InternalRow $row = (InternalRow) $input.next();
| ${consume(ctx, null, row).trim}
| if (shouldStop()) return;
| }
""".stripMargin
}
override def simpleString: String = "INPUT"
override def treeChildren: Seq[SparkPlan] = Nil
}
object WholeStageCodegen {
val PIPELINE_DURATION_METRIC = "duration"
}
/**
* WholeStageCodegen compile a subtree of plans that support codegen together into single Java
* function.
*
* Here is the call graph of to generate Java source (plan A support codegen, but plan B does not):
*
* WholeStageCodegen Plan A FakeInput Plan B
* =========================================================================
*
* -> execute()
* |
* doExecute() ---------> inputRDDs() -------> inputRDDs() ------> execute()
* |
* +-----------------> produce()
* |
* doProduce() -------> produce()
* |
* doProduce()
* |
* doConsume() <--------- consume()
* |
* doConsume() <-------- consume()
*
* SparkPlan A should override doProduce() and doConsume().
*
* doCodeGen() will create a CodeGenContext, which will hold a list of variables for input,
* used to generated code for BoundReference.
*/
case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSupport {
override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = child.outputPartitioning
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
override private[sql] lazy val metrics = Map(
"pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext,
WholeStageCodegen.PIPELINE_DURATION_METRIC))
/**
* Generates code for this subtree.
*
* @return the tuple of the codegen context and the actual generated source.
*/
def doCodeGen(): (CodegenContext, String) = {
val ctx = new CodegenContext
val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
val source = s"""
public Object generate(Object[] references) {
return new GeneratedIterator(references);
}
/** Codegened pipeline for:
* ${toCommentSafeString(child.treeString.trim)}
*/
final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
private Object[] references;
${ctx.declareMutableStates()}
public GeneratedIterator(Object[] references) {
this.references = references;
}
public void init(int index, scala.collection.Iterator inputs[]) {
partitionIndex = index;
${ctx.initMutableStates()}
}
${ctx.declareAddedFunctions()}
protected void processNext() throws java.io.IOException {
${code.trim}
}
}
""".trim
// try to compile, helpful for debug
val cleanedSource = CodeFormatter.stripExtraNewLines(source)
logDebug(s"\n${CodeFormatter.format(cleanedSource)}")
CodeGenerator.compile(cleanedSource)
(ctx, cleanedSource)
}
override def doExecute(): RDD[InternalRow] = {
// Generate the source code
val (ctx, cleanedSource) = doCodeGen()
val references = ctx.references.toArray
val durationMs = longMetric("pipelineTime")
// Obtain the input RDD(s)
val rdds = child.asInstanceOf[CodegenSupport].inputRDDs()
assert(rdds.size <= 2, "Up to two input RDDs can be supported")
// With a single input RDD, use mapPartitionsWithIndex
if (rdds.length == 1) {
rdds.head.mapPartitionsWithIndex { (index, iter) =>
val clazz = CodeGenerator.compile(cleanedSource)
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
buffer.init(index, Array(iter))
new Iterator[InternalRow] {
override def hasNext: Boolean = {
val v = buffer.hasNext
if (!v) durationMs += buffer.durationMs()
v
}
override def next: InternalRow = buffer.next()
}
}
} else {
// Right now, we support up to two input RDDs.
// The two RDDs must have the same number of partitions (required by zipPartitions)
rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
val partitionIndex = TaskContext.getPartitionId()
val clazz = CodeGenerator.compile(cleanedSource)
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
buffer.init(partitionIndex, Array(leftIter, rightIter))
new Iterator[InternalRow] {
override def hasNext: Boolean = {
val v = buffer.hasNext
if (!v) durationMs += buffer.durationMs()
v
}
override def next: InternalRow = buffer.next()
}
}
}
}
override def inputRDDs(): Seq[RDD[InternalRow]] = {
throw new UnsupportedOperationException
}
override def doProduce(ctx: CodegenContext): String = {
throw new UnsupportedOperationException
}
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val doCopy = if (ctx.copyResult) {
".copy()"
} else {
""
}
s"""
|${row.code}
|append(${row.value}$doCopy);
""".stripMargin.trim
}
override def innerChildren: Seq[SparkPlan] = {
child :: Nil
}
private def collectInputs(plan: SparkPlan): Seq[SparkPlan] = plan match {
case InputAdapter(c) => c :: Nil
case other => other.children.flatMap(collectInputs)
}
override def treeChildren: Seq[SparkPlan] = {
collectInputs(child)
}
override def simpleString: String = "WholeStageCodegen"
}
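The generated class extends BufferedRowIterator, and the emitted loops call append() and shouldStop(). The Java sketch below shows how such a buffering base class can turn a push-style processNext() into a normal pull iterator. It is a simplified model under stated assumptions, not Spark's actual class: ToyBufferedIterator and ToyGeneratedIterator are hypothetical, rows are plain Long values, and metrics are omitted.

```java
import java.util.ArrayDeque;
import java.util.Iterator;
import java.util.NoSuchElementException;

/** Hypothetical, simplified stand-in for a buffering row iterator. */
abstract class ToyBufferedIterator implements Iterator<Long> {
    private final ArrayDeque<Long> buffer = new ArrayDeque<>();

    /** Called by the generated code to hand over one output row. */
    protected void append(long row) {
        buffer.add(row);
    }

    /** The generated loop should return as soon as at least one row is buffered. */
    protected boolean shouldStop() {
        return !buffer.isEmpty();
    }

    /** The fused pipeline; in Spark this body is the generated Java source. */
    protected abstract void processNext();

    @Override
    public boolean hasNext() {
        if (buffer.isEmpty()) {
            processNext();          // resume the pipeline until it appends a row or finishes
        }
        return !buffer.isEmpty();
    }

    @Override
    public Long next() {
        if (!hasNext()) {
            throw new NoSuchElementException();
        }
        return buffer.remove();
    }
}

/** Hand-written equivalent of generated code for range(start, end).filter(odd). */
final class ToyGeneratedIterator extends ToyBufferedIterator {
    private long i;                 // loop state lives in a field so the loop can be resumed
    private final long end;

    ToyGeneratedIterator(long start, long end) {
        this.i = start;
        this.end = end;
    }

    @Override
    protected void processNext() {
        while (i < end) {
            long value = i++;
            if (value % 2 != 1) continue;   // Filter
            append(value);                  // root doConsume: append the row
            if (shouldStop()) return;       // give control back to the caller
        }
    }
}
```

Because the loop counter is a field rather than a local variable, processNext() can return after each appended row and continue where it left off on the next hasNext() call; this appears to be why the generated code registers its state through ctx.addMutableState.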
/**
* Find the chained plans that support codegen, collapse them together as WholeStageCodegen.
* That is: locate each chain of plans that support codegen and merge it into a single WholeStageCodegen node.
*/
case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
private def supportCodegen(e: Expression): Boolean = e match {
case e: LeafExpression => true
case e: CaseWhen => e.shouldCodegen
// CodegenFallback requires the input to be an InternalRow
case e: CodegenFallback => false
case _ => true
}
private def numOfNestedFields(dataType: DataType): Int = dataType match {
case dt: StructType => dt.fields.map(f => numOfNestedFields(f.dataType)).sum
case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType)
case a: ArrayType => numOfNestedFields(a.elementType)
case u: UserDefinedType[_] => numOfNestedFields(u.sqlType)
case _ => 1
}
private def supportCodegen(plan: SparkPlan): Boolean = plan match {
case plan: CodegenSupport if plan.supportCodegen =>
val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined)
// the generated code will be huge if there are too many columns
val hasTooManyOutputFields =
numOfNestedFields(plan.schema) > conf.wholeStageMaxNumFields
val hasTooManyInputFields =
plan.children.map(p => numOfNestedFields(p.schema)).exists(_ > conf.wholeStageMaxNumFields)
!willFallback && !hasTooManyOutputFields && !hasTooManyInputFields
case _ => false
}
/**
* Inserts a InputAdapter on top of those that do not support codegen.
*/
private def insertInputAdapter(plan: SparkPlan): SparkPlan = plan match {
case j @ SortMergeJoin(_, _, _, _, left, right) if j.supportCodegen =>
// The children of SortMergeJoin should do codegen separately.
// Insert an InputAdapter on both the left and the right child
j.copy(left = InputAdapter(insertWholeStageCodegen(left)),
right = InputAdapter(insertWholeStageCodegen(right)))
case p if !supportCodegen(p) =>
// collapse them recursively
// The node does not support codegen: wrap it as InputAdapter(insertWholeStageCodegen(p))
InputAdapter(insertWholeStageCodegen(p))
case p =>
// The node supports codegen: recurse into its children (a leaf without children is returned unchanged)
p.withNewChildren(p.children.map(insertInputAdapter))
}
/**
* Inserts a WholeStageCodegen on top of those that support codegen.
*/
private def insertWholeStageCodegen(plan: SparkPlan): SparkPlan = plan match {
// The plan supports codegen
case plan: CodegenSupport if supportCodegen(plan) =>
WholeStageCodegen(insertInputAdapter(plan))
case other =>
// Otherwise recurse into the children
other.withNewChildren(other.children.map(insertWholeStageCodegen))
}
def apply(plan: SparkPlan): SparkPlan = {
// Only rewrite the plan when whole-stage codegen is enabled
if (conf.wholeStageEnabled) {
insertWholeStageCodegen(plan)
} else {
plan
}
}
}
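The mutual recursion between insertWholeStageCodegen and insertInputAdapter is easier to follow on a toy tree. The Java sketch below models only the two generic cases: ToyPlan and ToyCollapseStages are hypothetical names, codegen support is reduced to a boolean flag, and the SortMergeJoin special case and the field-count limits are left out.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical plan node: a name, a codegen flag and children. */
final class ToyPlan {
    final String name;
    final boolean supportsCodegen;
    final List<ToyPlan> children;

    ToyPlan(String name, boolean supportsCodegen, ToyPlan... children) {
        this.name = name;
        this.supportsCodegen = supportsCodegen;
        this.children = Arrays.asList(children);
    }

    ToyPlan withChildren(List<ToyPlan> newChildren) {
        return new ToyPlan(name, supportsCodegen, newChildren.toArray(new ToyPlan[0]));
    }

    @Override
    public String toString() {
        if (children.isEmpty()) {
            return name;
        }
        StringBuilder sb = new StringBuilder(name).append('(');
        for (int i = 0; i < children.size(); i++) {
            if (i > 0) sb.append(", ");
            sb.append(children.get(i));
        }
        return sb.append(')').toString();
    }
}

final class ToyCollapseStages {
    /** Puts a WholeStageCodegen node on top of every maximal codegen-capable subtree. */
    static ToyPlan insertWholeStageCodegen(ToyPlan plan) {
        if (plan.supportsCodegen) {
            return new ToyPlan("WholeStageCodegen", false, insertInputAdapter(plan));
        }
        List<ToyPlan> rewritten = new ArrayList<>();
        for (ToyPlan child : plan.children) {
            rewritten.add(insertWholeStageCodegen(child));
        }
        return plan.withChildren(rewritten);
    }

    /** Inside a codegen subtree: cut it off with InputAdapter at nodes lacking codegen. */
    static ToyPlan insertInputAdapter(ToyPlan plan) {
        if (!plan.supportsCodegen) {
            return new ToyPlan("InputAdapter", true, insertWholeStageCodegen(plan));
        }
        List<ToyPlan> rewritten = new ArrayList<>();
        for (ToyPlan child : plan.children) {
            rewritten.add(insertInputAdapter(child));
        }
        return plan.withChildren(rewritten);
    }

    public static void main(String[] args) {
        ToyPlan plan = new ToyPlan("Project", true,
            new ToyPlan("Filter", true,
                new ToyPlan("Exchange", false,
                    new ToyPlan("Range", true))));
        System.out.println(insertWholeStageCodegen(plan));
        // prints WholeStageCodegen(Project(Filter(InputAdapter(Exchange(WholeStageCodegen(Range))))))
    }
}
```

In this example Exchange does not support codegen, so it splits the plan into two generated stages: Project and Filter are fused above it, and Range gets its own stage below it.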