Spark 之 logical plan

Cast

Cast 强制类型转换发生在 Logical Plan 转成 Analyzed Logical Plan 的阶段。

类型强转规则会先跳过子节点尚未解析完成(childrenResolved 为 false)的表达式;
子节点解析完成后,再根据表达式重写的 inputTypes 方法所声明的期望类型,对各个子表达式做隐式类型转换,最终由 checkInputDataTypes 把子表达式的实际类型和 inputTypes 逐一进行校验。

    override protected def coerceTypes(
        plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
      // Skip nodes whose children have not been resolved yet.
      case e if !e.childrenResolved => e


  /**
   * Casts types according to the expected input types for [[Expression]]s.
   */
  object ImplicitTypeCasts extends TypeCoercionRule {
  ...
  /**
     * Tries to make expression `e` conform to `expectedType`.
     *
     * The target data type is looked up from `e.dataType` and `expectedType`. When that target
     * equals the expression's current data type, `e` is returned unchanged; otherwise `e` is
     * wrapped in a [[Cast]] to the target type.
     *
     * @param e            the expression to adapt.
     * @param expectedType the type the caller expects `e` to have.
     * @return the (possibly cast) expression, or None when no implicit cast is available.
     */
    def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] = {
      for (targetType <- implicitCast(e.dataType, expectedType)) yield {
        if (targetType != e.dataType) Cast(e, targetType) else e
      }
    }
/**
 * Determines the data type that a value of `inType` should be implicitly cast to so that it
 * satisfies `expectedType`.
 *
 * The cases below are tried in order and the first match wins, so their ordering is part of
 * the behavior: the "already acceptable" guard comes first, and the fallback `case _` last.
 *
 * @param inType       the data type of the input.
 * @param expectedType the type expected for the input.
 * @return Some(target type) when the input is already acceptable (the target is then `inType`
 *         itself) or an implicit cast exists; None otherwise.
 */
private def implicitCast(inType: DataType, expectedType: AbstractDataType): Option[DataType] = {
      // Note that ret is nullable to avoid typing a lot of Some(...) in this local scope.
      // It is wrapped into an Option right after the match (null becomes None).
      @Nullable val ret: DataType = (inType, expectedType) match {
        // If the expected type already accepts the input type, no cast is needed.
        case _ if expectedType.acceptsType(inType) => inType

        // Cast null type (usually from null literals) into the target's default concrete type.
        case (NullType, target) => target.defaultConcreteType

        // If the function accepts any numeric type and the input is a string, we follow the hive
        // convention and cast that input into a double
        case (StringType, NumericType) => NumericType.defaultConcreteType

        // Implicit cast among numeric types. When we reach here, input type is not acceptable.

        // If input is a numeric type but not decimal, and we expect a decimal type,
        // cast the input to decimal.
        case (d: NumericType, DecimalType) => DecimalType.forType(d)
        // For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long
        case (_: NumericType, target: NumericType) => target

        // Implicit cast between date time types
        case (DateType, TimestampType) => TimestampType
        case (TimestampType, DateType) => DateType

        // Implicit cast from string to decimal, a concrete numeric type, date, timestamp or binary.
        case (StringType, DecimalType) => DecimalType.SYSTEM_DEFAULT
        case (StringType, target: NumericType) => target
        case (StringType, DateType) => DateType
        case (StringType, TimestampType) => TimestampType
        case (StringType, BinaryType) => BinaryType
        // Cast any atomic type other than string itself to string.
        case (any: AtomicType, StringType) if any != StringType => StringType

        // When we reach here, input type is not acceptable for any types in this type collection,
        // try to find the first one we can implicitly cast.
        case (_, TypeCollection(types)) =>
          types.flatMap(implicitCast(inType, _)).headOption.orNull

        // Implicit cast between array types.
        //
        // Compare the nullabilities of the from type and the to type, check whether the cast of
        // the nullability is resolvable by the following rules:
        // 1. If the nullability of the to type is true, the cast is always allowed;
        // 2. If the nullability of the to type is false, and the nullability of the from type is
        // true, the cast is never allowed;
        // 3. If the nullabilities of both the from type and the to type are false, the cast is
        // allowed only when Cast.forceNullable(fromType, toType) is false.
        // In the allowed cases the element type is resolved recursively; if the element type has
        // no implicit cast, the result is null.
        case (ArrayType(fromType, fn), ArrayType(toType: DataType, true)) =>
          implicitCast(fromType, toType).map(ArrayType(_, true)).orNull

        case (ArrayType(fromType, true), ArrayType(toType: DataType, false)) => null

        case (ArrayType(fromType, false), ArrayType(toType: DataType, false))
            if !Cast.forceNullable(fromType, toType) =>
          implicitCast(fromType, toType).map(ArrayType(_, false)).orNull

        // Implicit cast between Map types.
        // Follows the same semantics of implicit casting between two array types.
        // Refer to documentation above. Make sure that both key and values
        // can not be null after the implicit cast operation by calling forceNullable
        // method.
        case (MapType(fromKeyType, fromValueType, fn), MapType(toKeyType, toValueType, tn))
            if !Cast.forceNullable(fromKeyType, toKeyType) && Cast.resolvableNullability(fn, tn) =>
          // Reject the cast when the value cast is force-nullable but the target map does not
          // allow null values (tn is false).
          if (Cast.forceNullable(fromValueType, toValueType) && !tn) {
            null
          } else {
            // Both the key type and the value type must be implicitly castable.
            val newKeyType = implicitCast(fromKeyType, toKeyType).orNull
            val newValueType = implicitCast(fromValueType, toValueType).orNull
            if (newKeyType != null && newValueType != null) {
              MapType(newKeyType, newValueType, tn)
            } else {
              null
            }
          }

        // No implicit cast is available.
        case _ => null
      }
      Option(ret)
    }
  ...
  }
Expression 的 inputTypes 校验机制
/**
 * Validates this expression's inputs by delegating to
 * `ExpectsInputTypes.checkInputDataTypes`, passing `children` and `inputTypes`.
 *
 * @return the type check result produced by that call.
 */
override def checkInputDataTypes(): TypeCheckResult =
  ExpectsInputTypes.checkInputDataTypes(children, inputTypes)
object ExpectsInputTypes {

  /**
   * Checks each input expression against the expected type at the same position.
   *
   * Inputs and expected types are paired positionally with `zip`, so only positions present in
   * both sequences are checked. One message is collected for every input whose data type is
   * not accepted by its expected type; argument numbers in the messages are 1-based.
   *
   * @param inputs     the expressions to validate.
   * @param inputTypes the expected type for each input, in the same order.
   * @return `TypeCheckSuccess` when nothing mismatches, otherwise a `TypeCheckFailure` whose
   *         message is all mismatch messages joined by a single space.
   */
  def checkInputDataTypes(
      inputs: Seq[Expression],
      inputTypes: Seq[AbstractDataType]): TypeCheckResult = {
    val mismatches = for {
      ((input, expected), idx) <- inputs.zip(inputTypes).zipWithIndex
      if !expected.acceptsType(input.dataType)
    } yield {
      s"argument ${idx + 1} requires ${expected.simpleString} type, " +
        s"however, '${input.sql}' is of ${input.dataType.catalogString} type."
    }

    if (mismatches.nonEmpty) {
      TypeCheckResult.TypeCheckFailure(mismatches.mkString(" "))
    } else {
      TypeCheckResult.TypeCheckSuccess
    }
  }
}

这里以 ShiftLeft 为例:

/**
 * Expression for the bitwise left shift `left << right`.
 *
 * The base is declared as an integer or a long and the shift amount as an integer
 * (see `inputTypes`); the result has the same data type as the base.
 *
 * @param left the base number to shift.
 * @param right number of bits to left shift.
 */
@ExpressionDescription(
  usage = "_FUNC_(base, expr) - Bitwise left shift.",
  examples = """
    Examples:
      > SELECT _FUNC_(2, 1);
       4
  """,
  since = "1.5.0")
case class ShiftLeft(left: Expression, right: Expression)
  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {

  // First argument: int or long base. Second argument: int shift amount.
  override def inputTypes: Seq[AbstractDataType] = {
    val baseTypes = TypeCollection(IntegerType, LongType)
    Seq(baseTypes, IntegerType)
  }

  // The result keeps the data type of the base operand.
  override def dataType: DataType = left.dataType

  // Dispatches on the boxed runtime type of the base; the shift amount is read as an Integer.
  protected override def nullSafeEval(input1: Any, input2: Any): Any = input1 match {
    case intBase: jl.Integer => intBase << input2.asInstanceOf[jl.Integer]
    case longBase: jl.Long => longBase << input2.asInstanceOf[jl.Integer]
  }

  // Hands defineCodeGen a template that joins the two operand code strings with `<<`.
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val shiftTemplate = (base: String, bits: String) => s"$base << $bits"
    defineCodeGen(ctx, ev, shiftTemplate)
  }
}
zip 用法

zip 会把两个集合中的元素按顺序一一配对:


scala> val numbers = Seq(0, 1, 2, 3, 4)
numbers: Seq[Int] = List(0, 1, 2, 3, 4)

scala> val series = Seq(10, 11, 12, 13, 14)
series: Seq[Int] = List(10, 11, 12, 13, 14)

scala> numbers zip series
res0: Seq[(Int, Int)] = List((0,10), (1,11), (2,12), (3,13), (4,14))

scala> numbers.zip(series)
res1: Seq[(Int, Int)] = List((0,10), (1,11), (2,12), (3,13), (4,14))

如果其中一个集合比另一个长,多出的元素会被丢弃(结果的长度等于较短集合的长度)。因此上面 checkInputDataTypes 中的 inputs.zip(inputTypes) 只会校验两个集合中都存在的位置。
比如:

scala> val series = Seq(10, 11, 12, 13, 14, 15)
series: Seq[Int] = List(10, 11, 12, 13, 14, 15)

scala> numbers zip series
res2: Seq[(Int, Int)] = List((0,10), (1,11), (2,12), (3,13), (4,14))

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值