TL; DR该函数应返回类org.apache.spark.sql.Row的对象.
Spark提供了UDF定义的两个主要变体.
>使用Scala反射的udf变体:
> def udf [RT](f :()?RT)(隐式arg0:TypeTag [RT]):UserDefinedFunction
> def udf [RT,A1](f:(A1)?RT)(隐式arg0:TypeTag [RT],arg1:TypeTag [A1]):UserDefinedFunction
> ……
> def udf [RT,A1,A2,…,A10](f:(A1,A2,…,A10)?RT)(隐式arg0:TypeTag [RT],arg1:TypeTag [A1],arg2 :TypeTag [A2],…,arg10:TypeTag [A10])
哪个定义
Scala closure of … arguments as user-defined function (UDF). The data types are automatically inferred based on the Scala closure’s signature.
这些变体在没有原子或代数数据类型的模式的情况下使用.例如,有问题的函数将在Scala中定义:
case class Price(value: Double, currency: String)
val df = Seq("1 USD").toDF("price")
val toPrice = udf((s: String) => scala.util.Try {
s split(" ") match {
case Array(price, currency) => Price(price.toDouble, currency)
}
}.toOption)
df.select(toPrice($"price")).show
// +----------+
// |UDF(price)|
// +----------+
// |[1.0, USD]|
// +----------+
在此变体中,返回类型是自动编码的.
由于它依赖于反射,因此该变体主要用于Scala用户.
>提供模式定义的udf变体(您在此处使用的变体).此变体的返回类型应与数据集[Row]的返回类型相同:
>正如在另一个答案中指出的那样,您只能使用SQL types mapping table中列出的类型(原子类型为盒装或未装箱,java.sql.Timestamp / java.sql.Date,以及高级集合).
>使用org.apache.spark.sql.Row表示复杂结构(结构/结构类型).不允许与代数数据类型或等效数据混合.例如(Scala代码)
struct>>
应该表达为
Row(1, Row("foo", Row(-1.0, 42))))
不
(1, ("foo", (-1.0, 42))))
或任何混合变体,如
Row(1, Row("foo", (-1.0, 42))))
提供此变体主要是为了确保Java互操作性.
在这种情况下(相当于有问题的那个),定义应类似于以下定义:
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.Row
val schema = StructType(Seq(
StructField("value", DoubleType, false),
StructField("currency", StringType, false)
))
val toPrice = udf((s: String) => scala.util.Try {
s split(" ") match {
case Array(price, currency) => Row(price.toDouble, currency)
}
}.getOrElse(null), schema)
df.select(toPrice($"price")).show
// +----------+
// |UDF(price)|
// +----------+
// |[1.0, USD]|
// | null|
// +----------+
排除异常处理的所有细微差别(通常UDF应该控制空输入,按照惯例优雅地处理格式错误的数据)Java等效应该看起来或多或少像这样:
UserDefinedFunction price = udf((String s) -> {
String[] split = s.split(" ");
return RowFactory.create(Double.parseDouble(split[0]), split[1]);
}, DataTypes.createStructType(new StructField[]{
DataTypes.createStructField("value", DataTypes.DoubleType, true),
DataTypes.createStructField("currency", DataTypes.StringType, true)
}));
语境:
为了给你一些上下文,这种区别也反映在API的其他部分.例如,您可以从架构和一系列行创建DataFrame:
def createDataFrame(rows: List[Row], schema: StructType): DataFrame
或使用一系列产品的反射
def createDataFrame[A <: product seq arg0: typetag dataframe>
但不支持混合变体.
换句话说,您应该提供可以使用RowEncoder编码的输入.
当然你通常不会使用udf来执行这样的任务:
import org.apache.spark.sql.functions._
df.withColumn("price", struct(
split($"price", " ")(0).cast("double").alias("price"),
split($"price", " ")(1).alias("currency")
))
有关: