Whole-Stage Code Generation (wholeStageCodegen) in Spark: Notes on GenerateUnsafeProjection.createCode


Background

This post explains the GenerateUnsafeProjection.createCode method as it shows up in the whole-stage code generation path of RangeExec.

Analysis

The corresponding source is:

  def createCode(
      ctx: CodegenContext,
      expressions: Seq[Expression],
      useSubexprElimination: Boolean = false): ExprCode = {
    val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
    val exprSchemas = expressions.map(e => Schema(e.dataType, e.nullable))

    val numVarLenFields = exprSchemas.count {
      case Schema(dt, _) => !UnsafeRow.isFixedLength(dt)
      // TODO: consider large decimal and interval type
    }

    val rowWriterClass = classOf[UnsafeRowWriter].getName
    val rowWriter = ctx.addMutableState(rowWriterClass, "rowWriter",
      v => s"$v = new $rowWriterClass(${expressions.length}, ${numVarLenFields * 32});")

    // Evaluate all the subexpression.
    val evalSubexpr = ctx.subexprFunctionsCode

    val writeExpressions = writeExpressionsToBuffer(
      ctx, ctx.INPUT_ROW, exprEvals, exprSchemas, rowWriter, isTopLevel = true)
//   println(s"writeExpressions: $writeExpressions")
    val code =
      code"""
         |$rowWriter.reset();
         |$evalSubexpr
         |$writeExpressions
       """.stripMargin
    // `rowWriter` is declared as a class field, so we can access it directly in methods.
//    println(s"code: $code")
    ExprCode(code, FalseLiteral, JavaCode.expression(s"$rowWriter.getRow()", classOf[UnsafeRow]))
  }

Here expressions is Seq(BoundReference(0, long, false)) and useSubexprElimination is false.
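
Before walking through the method, here is a minimal sketch of calling createCode directly with the same expression list and printing the resulting ExprCode. It assumes a spark-catalyst dependency on the classpath; these are internal APIs, so the exact output and the generated variable names (which come from the codegen context, not from RangeExec, so they will not literally be range_value_0) may differ between Spark versions.

  import org.apache.spark.sql.catalyst.expressions.BoundReference
  import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, GenerateUnsafeProjection}
  import org.apache.spark.sql.types.LongType

  object CreateCodeDemo {
    def main(args: Array[String]): Unit = {
      val ctx = new CodegenContext

      // Same shape as the RangeExec case: ordinal 0, type long, not nullable.
      val exprs = Seq(BoundReference(0, LongType, nullable = false))

      val eval = GenerateUnsafeProjection.createCode(ctx, exprs)

      // eval.code holds the reset/write statements; eval.value is "<rowWriter>.getRow()".
      println(eval.code)
      println(eval.value)
    }
  }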

Walking through createCode step by step:

  • val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
    This only generates the per-expression code; here exprEvals evaluates to range_value_0.
    Because useSubexprElimination is false, no common-subexpression elimination is performed.
  • val exprSchemas = expressions.map(e => Schema(e.dataType, e.nullable))
    Collects the schema (data type and nullability) of each expression.
  • val numVarLenFields =
    Counts the fields whose type is not fixed-length; the count is used to initialize the UnsafeRowWriter (32 bytes of initial buffer per variable-length field).
  • val rowWriter =
    Declares and initializes rowWriter as class-level mutable state, i.e. a global field of the generated class; the generated code is:
     private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[3];
     public void init(int index, scala.collection.Iterator[] inputs) {
     ...
     range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
     }
    
  • val evalSubexpr = ctx.subexprFunctionsCode
    This is the empty string in this case.
  • val writeExpressions = writeExpressionsToBuffer(...), whose implementation is:
    private def writeExpressionsToBuffer(
        ctx: CodegenContext,
        row: String,
        inputs: Seq[ExprCode],
        schemas: Seq[Schema],
        rowWriter: String,
        isTopLevel: Boolean = false): String = {
      val resetWriter = if (isTopLevel) {
        // For top level row writer, it always writes to the beginning of the global buffer holder,
        // which means its fixed-size region always in the same position, so we don't need to call
        // `reset` to set up its fixed-size region every time.
        if (inputs.map(_.isNull).forall(_ == FalseLiteral)) {
          // If all fields are not nullable, which means the null bits never changes, then we don't
          // need to clear it out every time.
          ""
        } else {
          s"$rowWriter.zeroOutNullBytes();"
        }
      } else {
        s"$rowWriter.resetRowWriter();"
      }
    
      val writeFields = inputs.zip(schemas).zipWithIndex.map {
        case ((input, Schema(dataType, nullable)), index) =>
          val dt = UserDefinedType.sqlType(dataType)
    
          val setNull = dt match {
            case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
              // Can't call setNullAt() for DecimalType with precision larger than 18.
              s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});"
            case CalendarIntervalType => s"$rowWriter.write($index, (CalendarInterval) null);"
            case _ => s"$rowWriter.setNullAt($index);"
          }
    
          val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter)
          if (!nullable) {
            s"""
               |${input.code}
               |${writeField.trim}
             """.stripMargin
          } else {
            s"""
               |${input.code}
               |if (${input.isNull}) {
               |  ${setNull.trim}
               |} else {
               |  ${writeField.trim}
               |}
             """.stripMargin
          }
      }
    
      val writeFieldsCode = if (isTopLevel && (row == null || ctx.currentVars != null)) {
        // TODO: support whole stage codegen
        writeFields.mkString("\n")
      } else {
        assert(row != null, "the input row name cannot be null when generating code to write it.")
        ctx.splitExpressions(
          expressions = writeFields,
          funcName = "writeFields",
          arguments = Seq("InternalRow" -> row))
      }
      s"""
         |$resetWriter
         |$writeFieldsCode
       """.stripMargin
    }
    
    • val resetWriter =
      Because every input's isNull is FalseLiteral (the field is not nullable), resetWriter is the empty string.

    • val writeFields =
      Because the input is of type long, the matching branch in val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter) is
      case _ => s"$writer.write($index, $input);", so the generated code is:

       range_mutableStateArray_0[0].write(0, range_value_0);
      
    • val writeFieldsCode = and the subsequent assembly
      The write statements for the individual fields are joined with newlines.

  • val code =
    Assembles the code part of the resulting ExprCode; the generated code is:
    range_mutableStateArray_0[0].reset();
    
    range_mutableStateArray_0[0].write(0, range_value_0);
    
    A hand-written equivalent of these statements is sketched right after this list.
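
To see concretely what the two generated statements do, here is a minimal hand-written sketch, assuming a spark-catalyst dependency. UnsafeRowWriter(1, 0) mirrors the writer constructed in init() above: one field, and an initial variable-length buffer of 0 bytes, which is enough because long is a fixed-length type.

  import org.apache.spark.sql.catalyst.expressions.UnsafeRow
  import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
  import org.apache.spark.sql.types.LongType

  object RowWriterDemo {
    def main(args: Array[String]): Unit = {
      // long is fixed-length, so it is not counted in numVarLenFields
      // and no variable-length buffer space needs to be reserved.
      assert(UnsafeRow.isFixedLength(LongType))

      // Mirrors range_mutableStateArray_0[0] = new UnsafeRowWriter(1, 0);
      val rowWriter = new UnsafeRowWriter(1, 0)

      val rangeValue0 = 7L            // stand-in for the generated range_value_0
      rowWriter.reset()               // range_mutableStateArray_0[0].reset();
      rowWriter.write(0, rangeValue0) // range_mutableStateArray_0[0].write(0, range_value_0);

      val row = rowWriter.getRow()    // what the ExprCode's value expression evaluates to
      println(row.getLong(0))         // prints 7
    }
  }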

The complete ExprCode is therefore:

 ExprCode(
   code   = range_mutableStateArray_0[0].reset();
            range_mutableStateArray_0[0].write(0, range_value_0);,
   isNull = false,
   value  = (range_mutableStateArray_0[0].getRow()))
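
Outside of whole-stage codegen, the same generator can be exercised end to end via UnsafeProjection.create, which for expressions like this one is typically compiled through GenerateUnsafeProjection (with an interpreted fallback). A minimal sketch, again assuming a spark-catalyst dependency:

  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeProjection}
  import org.apache.spark.sql.types.LongType

  object ProjectionDemo {
    def main(args: Array[String]): Unit = {
      // The same expression list as above: ordinal 0, type long, not nullable.
      val exprs = Seq(BoundReference(0, LongType, nullable = false))

      // The compiled projection wraps code equivalent to the ExprCode shown above.
      val projection = UnsafeProjection.create(exprs)

      val unsafeRow = projection(InternalRow(42L))
      println(unsafeRow.getLong(0)) // prints 42
    }
  }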