spark WholeStageCodegen代码生成过程详解

spark的执行计划如果实现了CodegenSupport的特质,则可以实现代码的生成。
这里用iceberg表的insert语法跟着例子走一遍。

insert into local.ods.member2
select name, age
from local.ods.member1

生成的执行计划如下所示:

AppendDataExec
  WholeStageCodegenExec
    ProjectExec
      BatchScan
        BatchScanExec

AppendDataExec最终执行的方法是WriteToDataSouceV2Exec的writeWithV2方法,里面会执行val tempRdd = query.execute()也就是select的查询rdd结果。

这里的execute方法会调用WholeStageCodegenExec的doExecute()方法。这个doExecute()里面就是在做代码的生成和执行,最后得到select的查询结果。

具体的代码生成过程调用了doCodeGen()方法,这个过程中会调用子节点ProjectExec的produce方法。

final def produce(ctx: CodegenContext, parent: CodegenSupport): String = executeQuery {
  this.parent = parent
  ctx.freshNamePrefix = variablePrefix
  s"""
     |${ctx.registerComment(s"PRODUCE: ${this.simpleString(SQLConf.get.maxToStringFields)}")}
     |${doProduce(ctx)}
   """.stripMargin
}

这里会调用doProduce方法,最终会调用WholeStageCodegenExec的doProduce方法。doProduce的过程源码中给出了生成聚合代码的例子:

/**
 * Generate the Java source code to process, should be overridden by subclass to support codegen.
 *
 * doProduce() usually generate the framework, for example, aggregation could generate this:
 *
 *   if (!initialized) {
 *     # create a hash map, then build the aggregation hash map
 *     # call child.produce()
 *     initialized = true;
 *   }
 *   while (hashmap.hasNext()) {
 *     row = hashmap.next();
 *     # build the aggregation results
 *     # create variables for results
 *     # call consume(), which will call parent.doConsume()
 *      if (shouldStop()) return;
 *   }
 */

生成的代码大概的流程是:如果没有被初始化,就创建一个hash map,对聚合的hash map进行构建,递归调用子节点的produce过程,并将改节点的初始化置为true。对构建好的hash map进行遍历,构建出聚合结果和结果的变量,并调用consume进行消费,这个过程也会调用父节点的doConsume,判断是否需要停止。
这里生成的是processNext()函数中的核心逻辑。这个例子中整体生成的代码如下所示:

public Object generate(Object[] references) {
  return new GeneratedIteratorForCodegenStage1(references);
}

/*wsc_codegenStageId*/
final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
  private Object[] references;
  private scala.collection.Iterator[] inputs;
  private scala.collection.Iterator inputadapter_input_0;
  private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] project_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[1];

  public GeneratedIteratorForCodegenStage1(Object[] references) {
    this.references = references;
  }

  public void init(int index, scala.collection.Iterator[] inputs) {
    partitionIndex = index;
    this.inputs = inputs;
    inputadapter_input_0 = inputs[0];
    project_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 32);
  }

  protected void processNext() throws java.io.IOException {
    while ( inputadapter_input_0.hasNext()) {
      InternalRow inputadapter_row_0 = (InternalRow) inputadapter_input_0.next();

      // common sub-expressions

      boolean inputadapter_isNull_0 = inputadapter_row_0.isNullAt(0);
      UTF8String inputadapter_value_0 = inputadapter_isNull_0 ? null : (inputadapter_row_0.getUTF8String(0));
      boolean inputadapter_isNull_1 = inputadapter_row_0.isNullAt(1);
      int inputadapter_value_1 = inputadapter_isNull_1 ? -1 : (inputadapter_row_0.getInt(1));
      project_mutableStateArray_0[0].reset();

      project_mutableStateArray_0[0].zeroOutNullBytes();

      if (inputadapter_isNull_0) {
        project_mutableStateArray_0[0].setNullAt(0);
      } else {
        project_mutableStateArray_0[0].write(0, inputadapter_value_0);
      }

      if (inputadapter_isNull_1) {
        project_mutableStateArray_0[0].setNullAt(1);
      } else {
        project_mutableStateArray_0[0].write(1, inputadapter_value_1);
      }
      append((project_mutableStateArray_0[0].getRow()));
      if (shouldStop()) return;
    }
  }
}

生成代码之后,对生成的代码进行编译,如果异常或者编译的代码过长则走常规流程计算子节点的RDD[InternalRow]结果。

val (_, compiledCodeStats) = try {
  CodeGenerator.compile(cleanedSource)
} catch {
  case NonFatal(_) if !Utils.isTesting && sqlContext.conf.codegenFallback =>
    // We should already saw the error message
    logWarning(s"Whole-stage codegen disabled for plan (id=$codegenStageId):\n $treeString")
    return child.execute()
}

// Check if compiled code has a too large function
if (compiledCodeStats.maxMethodCodeSize > sqlContext.conf.hugeMethodLimit) {
  logInfo(s"Found too long generated codes and JIT optimization might not work: " +
    s"the bytecode size (${compiledCodeStats.maxMethodCodeSize}) is above the limit " +
    s"${sqlContext.conf.hugeMethodLimit}, and the whole-stage codegen was disabled " +
    s"for this plan (id=$codegenStageId). To avoid this, you can raise the limit " +
    s"`${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}`:\n$treeString")
  return child.execute()
}

编译通过之后对递归计算子节点的inputRDDs。这里列出主要的调用流程:

val rdds = child.asInstanceOf[CodegenSupport].inputRDDs()

def inputRDDs(): Seq[RDD[InternalRow]]

override def inputRDDs(): Seq[RDD[InternalRow]] = {
  inputRDD :: Nil
}

override def inputRDD: RDD[InternalRow] = child.execute()

因为iceberg中使用的是datasourcev2的表,因此child.execute最终调用的是DataSourceV2ScanExecBase的doExecute方法:

override def doExecute(): RDD[InternalRow] = {
  val numOutputRows = longMetric("numOutputRows")
  inputRDD.map { r =>
    numOutputRows += 1
    r
  }
}

这里的inputRDD是BatchScanExec中的inputRDD:

override lazy val inputRDD: RDD[InternalRow] = {
  new DataSourceRDD(sparkContext, partitions, readerFactory, supportsColumnar)
}

获取到子节点的RDD[InternalRow]后,再次编译生成的代码,前面一次编译的目的是确保编译的成功,这一次编译才会真正的使用。将RDD通过init方法初始化反射生成的类,最后返回整体查询的RDD结果。也就是val tempRdd = query.execute()的结果。以上便是WholeStageCodegen生成代码,并通过生成代码获取查询的RDD结果的过程。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
下面是一个简单的Spark WordCount程序: ```python from pyspark import SparkConf, SparkContext conf = SparkConf().setAppName("WordCount") sc = SparkContext(conf=conf) text_file = sc.textFile("hdfs://localhost:9000/input/sample.txt") words = text_file.flatMap(lambda line: line.split(" ")) word_counts = words.map(lambda word: (word, 1)).reduceByKey(lambda a, b: a + b) word_counts.saveAsTextFile("hdfs://localhost:9000/output/wordcount") ``` 程序的功能是统计一个文本文件中每个单词出现的次数,并将结果保存到HDFS上。 下面是代码的详细解释: 首先,我们导入了`SparkConf`和`SparkContext`。这两个类是Spark的核心类,`SparkConf`用于配置Spark应用程序的属性,`SparkContext`用于连接Spark集群。 ```python from pyspark import SparkConf, SparkContext ``` 接下来,我们创建了一个`SparkConf`实例,并给应用程序命名为`WordCount`。我们还创建了一个`SparkContext`实例,并将`SparkConf`传递给它。这些代码将初始化Spark应用程序并连接到Spark集群。 ```python conf = SparkConf().setAppName("WordCount") sc = SparkContext(conf=conf) ``` 然后,我们使用`textFile()`方法从HDFS中读取输入文件,并创建一个RDD(弹性分布式数据集)。 ```python text_file = sc.textFile("hdfs://localhost:9000/input/sample.txt") ``` 接下来,我们使用`flatMap()`方法将每行文本拆分成单词,并创建一个新的RDD。 ```python words = text_file.flatMap(lambda line: line.split(" ")) ``` 然后,我们使用`map()`方法将每个单词转换为一个`(单词, 1)`的键值对,并创建一个新的RDD。 ```python word_counts = words.map(lambda word: (word, 1)) ``` 接下来,我们使用`reduceByKey()`方法对每个单词的计数进行聚合,并创建一个新的RDD。 ```python word_counts = word_counts.reduceByKey(lambda a, b: a + b) ``` 最后,我们使用`saveAsTextFile()`方法将结果保存到HDFS上,并指定输出目录。 ```python word_counts.saveAsTextFile("hdfs://localhost:9000/output/wordcount") ``` 这就是完整的Spark WordCount程序。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值