spark-sql-dataset源码解析(1)

spark的sql/dataset的基本操作

showstring方法

  private[sql] def showString(
      _numRows: Int,
      truncate: Int = 20,
      vertical: Boolean = false): String = {
    val numRows = _numRows.max(0).min(ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - 1)
    // Get rows represented by Seq[Seq[String]], we may get one more line if it has more data.
    val tmpRows = getRows(numRows, truncate)

    val hasMoreData = tmpRows.length - 1 > numRows
    val rows = tmpRows.take(numRows + 1)

    val sb = new StringBuilder
    val numCols = schema.fieldNames.length
    // We set a minimum column width at '3'
    val minimumColWidth = 3

    if (!vertical) {
      // Initialise the width of each column to a minimum value
      val colWidths = Array.fill(numCols)(minimumColWidth)

      // Compute the width of each column
      for (row <- rows) {
        for ((cell, i) <- row.zipWithIndex) {
          colWidths(i) = math.max(colWidths(i), Utils.stringHalfWidth(cell))
        }
      }

      val paddedRows = rows.map { row =>
        row.zipWithIndex.map { case (cell, i) =>
          if (truncate > 0) {
            StringUtils.leftPad(cell, colWidths(i) - Utils.stringHalfWidth(cell) + cell.length)
          } else {
            StringUtils.rightPad(cell, colWidths(i) - Utils.stringHalfWidth(cell) + cell.length)
          }
        }
      }

      // Create SeparateLine
      val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString()

      // column names
      paddedRows.head.addString(sb, "|", "|", "|\n")
      sb.append(sep)

      // data
      paddedRows.tail.foreach(_.addString(sb, "|", "|", "|\n"))
      sb.append(sep)
    } else {
      // Extended display mode enabled
      val fieldNames = rows.head
      val dataRows = rows.tail

      // Compute the width of field name and data columns
      val fieldNameColWidth = fieldNames.foldLeft(minimumColWidth) { case (curMax, fieldName) =>
        math.max(curMax, Utils.stringHalfWidth(fieldName))
      }
      val dataColWidth = dataRows.foldLeft(minimumColWidth) { case (curMax, row) =>
        math.max(curMax, row.map(cell => Utils.stringHalfWidth(cell)).max)
      }

      dataRows.zipWithIndex.foreach { case (row, i) =>
        // "+ 5" in size means a character length except for padded names and data
        val rowHeader = StringUtils.rightPad(
          s"-RECORD $i", fieldNameColWidth + dataColWidth + 5, "-")
        sb.append(rowHeader).append("\n")
        row.zipWithIndex.map { case (cell, j) =>
          val fieldName = StringUtils.rightPad(fieldNames(j),
            fieldNameColWidth - Utils.stringHalfWidth(fieldNames(j)) + fieldNames(j).length)
          val data = StringUtils.rightPad(cell,
            dataColWidth - Utils.stringHalfWidth(cell) + cell.length)
          s" $fieldName | $data "
        }.addString(sb, "", "\n", "\n")
      }
    }

    // Print a footer
    if (vertical && rows.tail.isEmpty) {
      // In a vertical mode, print an empty row set explicitly
      sb.append("(0 rows)\n")
    } else if (hasMoreData) {
      // For Data that has more than "numRows" records
      val rowsString = if (numRows == 1) "row" else "rows"
      sb.append(s"only showing top $numRows $rowsString\n")
    }

    sb.toString()
  }

这段代码是Scala中Spark SQL中的showString方法,用于将DataFrame或Dataset的内容以表格形式展示出来。
该方法包含三个参数:_numRows表示要展示的记录数量,truncate表示每个单元格的最大字符数(如果超过,则截断显示),vertical表示是否以垂直模式展示。如果是水平模式,就输出表头和数据的每行;如果是垂直模式,就以扩展显示模式输出每列的字段名和数据值。最后,添加页脚信息,提示是否还有更多的数据。

val numRows = _numRows.max(0).min(ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - 1)

确保numRows不超过限制,大于0小于ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - 1
getRows,主要用于获取数据表格中要展示的数据行信息(行列式样)。
val hasMoreData = tmpRows.length - 1 > numRows: 这一行代码计算是否还有更多的数据需要被输出,通过比较需要输出的行数 numRows 和字符串列表 tmpRows 长度减1(因为页码从0开始)的大小关系来确定。如果 tmpRows 长度减1 大于 numRows,则说明还有数据未被输出,设置变量 hasMoreData 为 true,否则为 false。

val rows = tmpRows.take(numRows + 1): 这一行代码根据指定的每页输出行数 numRows 将字符串列表 tmpRows 进行切分,取出前 numRows + 1 行字符串,赋值给变量 rows,其中加1 是为了保证最后一页也能包含 numRows 行字符串(如果不够 numRows 行,就全部输出)。接下来,showstring() 函数会将变量 rows 中的字符串连接起来形成一个新的字符串,并返回给调用者。

getRows

private[sql] def getRows(
      numRows: Int,
      truncate: Int): Seq[Seq[String]] = {
    val newDf = logicalPlan match {
      case c: CommandResult =>
        // Convert to `LocalRelation` and let `ConvertToLocalRelation` do the casting locally to
        // avoid triggering a job
        Dataset.ofRows(sparkSession, LocalRelation(c.output, c.rows))
      case _ => toDF()
    }
    val castCols = newDf.logicalPlan.output.map { col =>
      Column(ToPrettyString(col))
    }
    val data = newDf.select(castCols: _*).take(numRows + 1)

    // For array values, replace Seq and Array with square brackets
    // For cells that are beyond `truncate` characters, replace it with the
    // first `truncate-3` and "..."
    schema.fieldNames.map(SchemaUtils.escapeMetaCharacters).toSeq +: data.map { row =>
      row.toSeq.map { cell =>
        assert(cell != null, "ToPrettyString is not nullable and should not return null value")
        // Escapes meta-characters not to break the `showString` format
        val str = SchemaUtils.escapeMetaCharacters(cell.toString)
        if (truncate > 0 && str.length > truncate) {
          // do not show ellipses for strings shorter than 4 characters.
          if (truncate < 4) str.substring(0, truncate)
          else str.substring(0, truncate - 3) + "..."
        } else {
          str
        }
      }: Seq[String]
    }
  }

这是一个私有方法 getRows,主要用于获取数据表格中要展示的数据行信息(行列式样)。
需要判断logicalPlan对象是不是CommandResult,如果是,将Dataset.ofRows(sparkSession, LocalRelation(c.output, c.rows))结果付给c。
CommandResult表示一些轻量级的sql命令的执行结果。

利用 newDf 的逻辑计划的输出,将每个列名映射到一个 Column 对象,然后取出第一行数据并对其值进行类型转换,以便后续可以格式化输出。

对新的 Dataset 对象执行 select 操作,选取之前构造的 Column 列表中的所有列,并获取前 numRows + 1 行结果数据。此处需要额外获取一行数据,是为了判断是否需要截断显示,即当数据量大于 numRows 时,需要在末尾添加省略号。

最后,对于每个单元格的字符串值,如果其长度超过截断值,则将其截断为指定长度,并添加省略号。如果字符串值包含元字符,则需要进行转义,以避免输出格式的破坏。

下面看ofRows方法

def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame =
  sparkSession.withActive {
    val qe = sparkSession.sessionState.executePlan(logicalPlan)
    qe.assertAnalyzed()
    new Dataset[Row](qe, RowEncoder(qe.analyzed.schema))
  }

ofRows 方法,用于将一个逻辑执行计划(LogicalPlan)转化为 Dataset。qe表示查询的执行计划,用assertAnalyzed()验证逻辑计划是否已经被分析和优化,然后根据执行计划的元数据(schema),构建一个 RowEncoder 对象,并将其作为参数传入 Dataset 的构造函数中,创建一个新的 Dataset[Row] 对象,并返回。qe对象是一个查询执行器,包含了一个查询的所有上下文信息,包括逻辑计划(Logical Plan)、物理计划(Physical Plan)、数据源等。qe.analyzed.schema 表示查询语句经过解析和优化后所得到的元数据(即Schema),它是一个 StructType 类型的对象。可以通过 qe.analyzed.schema 获取查询结果的列信息、列名、数据类型,以及是否可空等信息。

select()

  def select(cols: Column*): DataFrame = withPlan {
    val untypedCols = cols.map {
      case typedCol: TypedColumn[_, _] =>
        // Checks if a `TypedColumn` has been inserted with
        // specific input type and schema by `withInputType`.
        val needInputType = typedCol.expr.exists {
          case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty => true
          case _ => false
        }

        if (!needInputType) {
          typedCol
        } else {
          throw QueryCompilationErrors.cannotPassTypedColumnInUntypedSelectError(typedCol.toString)
        }

      case other => other
    }
    Project(untypedCols.map(_.named), logicalPlan)
  }

首先对传入的所有列进行类型校验。如果是一个类型化的列(即 TypedColumn 对象),则需要进一步判断其是否需要特定的输入类型或 schema。如果不需要,就直接使用当前列;否则,抛出异常,提示用户不能在非类型化的 select 操作中使用类型化的列。

Project 是用于选择或重命名列的逻辑计划节点。该节点将输入数据集中的某些列进行选择或重命名,并输出到下游逻辑计划节点或物理计划节点中。当执行查询语句时,Spark SQL 会自动将查询中的 SELECT 子句转换为一个或多个 Project 节点。最终返回dataframe。

project节点属性:

  • projectList:类型为 Seq[NamedExpression],表示投影列的列表,每个元素是一个命名表达式
  • child:类型为 LogicalPlan,表示该 Project 节点的子节点
    方法:
output:返回一个 Seq[Attribute],表示该节点的输出列。由于 Project 节点本质上是一个投影操作,因此它的输出列就等于其输入列(即子节点的输出列)中指定的列子集。
outputExpressions:返回一个 Seq[NamedExpression],表示该节点的输出列,每个元素是一个命名表达式。和 output 方法的区别是,outputExpressions 方法返回的命名表达式保留了原始名称和别名,而 output 方法只返回原始名称。
maxRows:返回一个 Option[Long],表示子节点最多能输出的行数。这个值通常由执行引擎根据查询计划的特性进行优化,比如加入了限制条件或者 Top N 操作,等等。
maxRowsPerPartition:返回一个 Option[Long],表示子节点每个分区最多能输出的行数。和 maxRows 方法类似,只是作用于每个分区。
nodePatterns:返回一个 Seq[TreePattern],表示该节点匹配的树模式。这个属性通常是用于优化查询计划的,比如在查询重写阶段自动应用某些规则。
resolved:返回一个 Boolean,表示该节点是否已经解析完成。在 Spark SQL 中,逻辑计划的节点通常需要经过三个阶段:未解析、部分解析和完全解析。只有在完全解析之后,计划才能被转换为物理计划,进一步执行。
validConstraints:返回一个 ExpressionSet,表示该节点可以应用的有效约束条件(即谓词)。这个属性通常是用于优化查询计划的,比如在查询重写阶段自动推断出某些额外的谓词并应用。
metadataOutput:返回一个 Seq[Attribute],表示此节点隱藏列的元數據输出。这个属性通常是用于存储数据集的统计信息,比如行数、均值等,也可以用来标记特殊的标识符和属性。
withNewChildInternal:返回一个新的 Project 节点,其中子节点替换为给定的逻辑计划。这个方法通常由查询优化器调用,在执行计划转换时更新节点信息。

Project执行操作时,需要转化为物理计划节点ProjectExec进行操作。具体执行时,在 ProjectExec 类中,会重写 doExecute() 方法,该方法用于实现 Project 节点的物理计算操作,即对输入数据集进行投影。在 doExecute() 方法中,会先通过调用 child.execute() 方法获取输入数据集,然后将输入数据集按照投影列表指定的列进行投影,并返回投影后的数据集作为输出。注意这里的 child 实际上就是 Project 节点对应的子节点,即要进行投影计算的输入数据集。

最终,doExecute() 方法会将投影后的数据集返回给 ProjectExec 的父节点,并由它的父节点持续传递,直到最终输出到 Spark 应用程序或存储系统中。

下面是代码 Project(untypedCols.map(_.named), logicalPlan) 的详细解释:

  1. 创建 Project 节点。在这段代码中,使用 Project 构造函数创建了一个 Project 节点。该节点的第一个参数是一个新的列列表,表示需要对哪些列进行选择或重命名;第二个参数是下游逻辑计划节点或物理计划节点,表示输出的数据将传递给哪个节点进行处理;
  2. 映射列名。在这段代码中,使用了 untypedCols.map(.named) 将 untypedCols 中的列转换为新的列列表,并将它们的列名称作为新列列表的列名。其中,.named 表示一个匿名函数,用于将 Column 对象转换为新的 NamedExpression 对象;
  3. 输出选择或重命名后的列。最终,Project 节点会将选择或重命名后的列输出到下游逻辑计划节点节点中。在执行查询时,Spark SQL 会自动将查询中的各种函数、聚合操作等转换为一个或多个 Project 节点,以便对数据集进行选择、过滤、计算等操作。

untypedCols 变量是一个未经类型化的列列表,如果在执行查询时出现了类型错误,则会导致查询失败

因此,select 函数的执行过程大致如下:

  1. 首先对输入的列进行类型校验,如果发现有类型化的列需要特定的输入类型或 schema,就抛出异常;
  2. 然后将选择后的列信息和当前 DataFrame 逻辑计划传递给 Project 构造函数,创建一个新的 Project 节点;
  3. Project 节点表示对数据集中的指定列进行选择或重命名,并输出到下游逻辑或物理计划节点中;
  4. 根据整个查询计划生成物理执行计划,并执行查询操作,返回结果数据集。

下一章我们继续分析filter方法

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值