SparkUserDefinedFunction源码分析
原理
用户自定义函数(UDF)是一种允许用户在 Spark SQL 中定义自己的函数并应用于 DataFrame 的功能。UDF 可以接受一个或多个输入参数,并生成一个输出结果。UDF 的目的是扩展 Spark SQL 的功能,使用户能够使用自定义逻辑对数据进行处理和转换。
在 Apache Spark 中,UDF 的原理是通过创建 UserDefinedFunction
对象来表示用户定义的函数,并将其应用于 DataFrame 的列。UserDefinedFunction
类封装了用户定义的函数对象、返回值类型和输入参数类型等信息。它提供了方法来配置 UDF 的属性,例如是否可空、是否确定性等。通过调用 apply
方法,可以将 UserDefinedFunction
应用到 DataFrame 的列上,从而得到一个新的 Column 对象,该对象包含了应用 UDF 后的结果。
示例1
下面是一个使用用户自定义函数的示例:
object udfTest1 {
// 定义一个自定义函数,将字符串转换为大写
val toUpper = udf((s: String) => s.toUpperCase)
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder
.master("local[2]")
.appName("appName").config("", true)
.getOrCreate()
import spark.implicits._
// 创建一个 DataFrame
val df = spark.sparkContext.parallelize(Seq(("apple"), ("orange"), ("banana")))
.toDF("fruit")
// 使用自定义函数将 fruit 列中的字符串转换为大写
val result = df.select(toUpper(col("fruit")).as("uppercase_fruit"))
// 显示结果
result.show()
}
}
输出结果:
+----------------+
|uppercase_fruit |
+----------------+
|APPLE |
|ORANGE |
|BANANA |
+----------------+
在示例中,我们首先定义了一个自定义函数 toUpper
,它将字符串转换为大写。然后,我们创建了一个包含水果名称的 DataFrame,并使用 toUpper
函数将 fruit
列中的字符串转换为大写。最后,我们选择转换后的结果,并显示出来。
这个示例演示了如何定义和应用一个简单的用户自定义函数,以及如何将其应用于 Spark SQL 的 DataFrame 上进行数据处理和转换。用户可以根据自己的需求定义更复杂的自定义函数,并应用于更复杂的数据处理任务中。
示例2
下面是一个示例,演示如何通过配置和操作 UserDefinedFunction
对象来设置用户自定义函数(UDF)的属性。
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.functions._
object udfTest {
// 定义一个自定义函数,判断字符串长度是否大于等于指定长度
val stringLengthUDF = udf((s: String, length: Int) => s.length >= length)
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder
.master("local[2]")
.appName("appName").config("", true)
.getOrCreate()
import spark.implicits._
// 创建一个 DataFrame
val df = spark.sparkContext.parallelize(Seq(("apple"), ("orange"), ("banana")))
.toDF("fruit")
// 使用自定义函数并设置属性
val udfObj = stringLengthUDF.withName("StringLengthUDF")
.asNonNullable()
.asNondeterministic()
// 将自定义函数应用到 DataFrame 的列
val result = df.select(udfObj(col("fruit"), lit(5)).as("is_long"))
// 显示结果
result.show()
}
}
输出结果:
+-------+
|is_long|
+-------+
| true |
| false |
| true |
+-------+
在示例中,我们首先定义了一个自定义函数 stringLengthUDF
,它接受一个字符串和一个整数参数,并返回该字符串的长度是否大于等于指定长度。然后,我们创建了一个包含水果名称的 DataFrame。接下来,我们对 stringLengthUDF
的 UserDefinedFunction
对象进行一系列属性设置:
- 使用
withName
方法将函数名称设置为 “StringLengthUDF”。 - 使用
asNonNullable
方法将函数设置为非可空。 - 使用
asNondeterministic
方法将函数设置为非确定性。
最后,我们将设置好的自定义函数应用于 DataFrame 的 fruit
列,并将结果保存在一个新的列 is_long
中。最终,我们显示了结果。
通过示例,我们展示了如何使用 UserDefinedFunction
对象来设置和操作用户自定义函数的属性。用户可以根据需要配置函数的可空性、确定性等属性,以满足具体的业务需求。
方法总结
这是 Apache Spark 中用于用户自定义函数(UDF)的功能。
-
定义 UserDefinedFunction 类:
UserDefinedFunction
是表示用户自定义函数的类,它接受函数对象f
、返回值类型dataType
和输入参数类型inputTypes
。 -
构造函数参数
-
f: AnyRef
:表示用户定义的函数对象。这个函数可以是任意类型(AnyRef),例如匿名函数、Lambda 表达式或方法引用。 -
dataType: DataType
:表示用户定义的函数的返回值类型。它指定了 UDF 的输出结果的数据类型,可以是 Spark SQL 中支持的任何有效数据类型,如整数、字符串、布尔值等。 -
inputTypes: Option[Seq[DataType]]
:表示用户定义的函数的输入参数类型。它是一个可选的序列,其中每个元素都指定了一个输入参数的数据类型。如果 UDF 不接受任何输入参数,则可以将其设置为 None。如果 UDF 接受多个输入参数,则可以使用 Seq[DataType] 指定每个参数的数据类型。
通过这些构造函数参数,
UserDefinedFunction
类能够接收用户定义的函数、返回值类型和输入参数类型,并在应用于 DataFrame 列时执行相应的操作。这样,用户就可以自定义函数并在 Spark SQL 中使用它们来处理和转换数据。 -
-
属性变量:
UserDefinedFunction
类包含以下属性变量:_nameOption
:存储函数名称的可选字符串。_nullable
:标识函数是否可以返回可空值的布尔值,默认为 true。_deterministic
:标识函数是否是确定性的,即给定相同的输入是否始终产生相同的输出,默认为 true。nullableTypes
:存储输入参数类型的可选序列,并记录参数是否可空。
-
方法定义:
nullable
方法:返回 UDF 是否可以返回可空值。deterministic
方法:返回 UDF 是否是确定性的。apply
方法:返回一个表达式,调用 UDF 并传入指定的列作为参数。copyAll
方法:复制当前对象的所有属性,并返回新的UserDefinedFunction
对象。withName
方法:更新UserDefinedFunction
的名称。asNonNullable
方法:将 UDF 更新为非可空。asNondeterministic
方法:将 UDF 更新为非确定性。
-
伴生对象 SparkUserDefinedFunction:
create
方法:根据给定的函数对象、返回值类型和输入参数模式创建一个新的UserDefinedFunction
对象。
该源码提供了用户定义自己的函数并在 Spark SQL 中使用的能力。通过创建和配置 UserDefinedFunction
对象,可以定义和应用 UDF,以对 DataFrame 进行复杂的数据处理和转换操作。用户可以根据需求设置 UDF 的属性,如名称、可空性和确定性等,以满足具体业务逻辑的要求。
中文源码
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.ScalaUDF
import org.apache.spark.sql.types.DataType
/**
* 用户自定义函数(User Defined Function,简称 UDF)。使用 `functions` 中的 `udf` 函数来创建一个 UDF。
*
* 例如:
* {{{
* // 定义一个根据分数判断是否大于 0.5 的 UDF
* val predict = udf((score: Double) => score > 0.5)
*
* // 在 DataFrame 中添加一个基于 score 列的预测列
* df.select( predict(df("score")) )
* }}}
*
* @since 1.3.0
*/
@InterfaceStability.Stable
case class UserDefinedFunction protected[sql] (
f: AnyRef,
dataType: DataType,
inputTypes: Option[Seq[DataType]]) {
private var _nameOption: Option[String] = None
private var _nullable: Boolean = true
private var _deterministic: Boolean = true
// 这是一个 `var`,为了保持这个 case class 的向后兼容性。
// TODO: 在 Spark 3.0 中重新审视这个 case class,并缩小公共接口。
private[sql] var nullableTypes: Option[Seq[Boolean]] = None
/**
* 当 UDF 可以返回可空值时返回 true。
*
* @since 2.3.0
*/
def nullable: Boolean = _nullable
/**
* 当 UDF 是确定性的时返回 true,即给定相同的输入是否始终产生相同的输出。
*
* @since 2.3.0
*/
def deterministic: Boolean = _deterministic
/**
* 返回一个表达式,调用 UDF 并传入指定的列作为参数。
*
* @since 1.3.0
*/
@scala.annotation.varargs
def apply(exprs: Column*): Column = {
// TODO: 确保此类仅通过 `SparkUserDefinedFunction.create()` 实例化,并始终设置 nullableTypes。
if (nullableTypes.isEmpty) {
nullableTypes = Some(ScalaReflection.getParameterTypeNullability(f))
}
if (inputTypes.isDefined) {
assert(inputTypes.get.length == nullableTypes.get.length)
}
Column(ScalaUDF(
f,
dataType,
exprs.map(_.expr),
nullableTypes.get,
inputTypes.getOrElse(Nil),
udfName = _nameOption,
nullable = _nullable,
udfDeterministic = _deterministic))
}
private def copyAll(): UserDefinedFunction = {
val udf = copy()
udf._nameOption = _nameOption
udf._nullable = _nullable
udf._deterministic = _deterministic
udf.nullableTypes = nullableTypes
udf
}
/**
* 使用给定的名称更新 UserDefinedFunction。
*
* @since 2.3.0
*/
def withName(name: String): UserDefinedFunction = {
val udf = copyAll()
udf._nameOption = Option(name)
udf
}
/**
* 将 UserDefinedFunction 更新为非可空。
*
* @since 2.3.0
*/
def asNonNullable(): UserDefinedFunction = {
if (!nullable) {
this
} else {
val udf = copyAll()
udf._nullable = false
udf
}
}
/**
* 将 UserDefinedFunction 更新为非确定性。
*
* @since 2.3.0
*/
def asNondeterministic(): UserDefinedFunction = {
if (!_deterministic) {
this
} else {
val udf = copyAll()
udf._deterministic = false
udf
}
}
}
// 这里使用的名称与 `UserDefinedFunction` 不同,以避免破坏自动生成的 UserDefinedFunction 对象的二进制兼容性。
private[sql] object SparkUserDefinedFunction {
def create(
f: AnyRef,
dataType: DataType,
inputSchemas: Seq[Option[ScalaReflection.Schema]]): UserDefinedFunction = {
val inputTypes = if (inputSchemas.contains(None)) {
None
} else {
Some(inputSchemas.map(_.get.dataType))
}
val udf = new UserDefinedFunction(f, dataType, inputTypes)
udf.nullableTypes = Some(inputSchemas.map(_.map(_.nullable).getOrElse(true)))
udf
}
}