Spark MLlib中的SQLTransformer源码解析
文章目录
1. 源码适用场景
SQLTransformer
实现了由SQL语句定义的转换。
- 只支持SQL语法类似于’SELECT … FROM THIS …',其中’THIS’表示输入数据集的底层表。
- select子句指定要在输出中显示的字段、常量和表达式,它可以是任何Spark SQL支持的select子句。
- 用户还可以使用Spark SQL内置函数和UDF对这些选定的列进行操作。例如,[[SQLTransformer]]支持如下语句:
SELECT a, a + b AS a_b FROM __THIS__
SELECT a, SQRT(b) AS b_sqrt FROM __THIS__ where a > 5
SELECT a, b, SUM(c) AS c_sum FROM __THIS__ GROUP BY a, b
2. 多种主要用法及其代码示例
用法1:选择字段并添加新的计算列
import org.apache.spark.ml.feature.SQLTransformer
// 创建SQLTransformer实例
val sqlTransformer = new SQLTransformer()
.setStatement("SELECT a, a + b AS a_b FROM __THIS__") // 设置SQL语句
// 应用SQLTransformer转换数据集
val outputData = sqlTransformer.transform(inputData)
用法2:选择字段并应用函数
import org.apache.spark.ml.feature.SQLTransformer
// 创建SQLTransformer实例
val sqlTransformer = new SQLTransformer()
.setStatement("SELECT a, SQRT(b) AS b_sqrt FROM __THIS__ where a > 5") // 设置SQL语句
// 应用SQLTransformer转换数据集
val outputData = sqlTransformer.transform(inputData)
用法3:选择字段并进行分组聚合
import org.apache.spark.ml.feature.SQLTransformer
// 创建SQLTransformer实例
val sqlTransformer = new SQLTransformer()
.setStatement("SELECT a, b, SUM(c) AS c_sum FROM __THIS__ GROUP BY a, b") // 设置SQL语句
// 应用SQLTransformer转换数据集
val outputData = sqlTransformer.transform(inputData)
3. 方法介绍
setStatement(value: String): this.type
:设置SQL语句。getStatement: String
:获取当前设置的SQL语句。transform(dataset: Dataset[_]): DataFrame
:根据SQL语句对输入数据集进行转换,并返回转换后的数据集。transformSchema(schema: StructType): StructType
:根据SQL语句推断输出数据集的结构。copy(extra: ParamMap): SQLTransformer
:创建并返回当前实例的副本。
4. 官方链接
5.中文源码
/**
* 实现由SQL语句定义的转换。
* 目前我们仅支持类似于 'SELECT ... FROM __THIS__ ...' 的SQL语法,
* 其中 '__THIS__' 表示输入数据集的底层表。
* SELECT子句指定输出中要显示的字段、常量和表达式,
* 它可以是任何Spark SQL支持的select子句。
* 用户还可以使用Spark SQL内置函数和UDFs对这些选定的列进行操作。
* 例如,[[SQLTransformer]]支持如下语句:
*
* {{{
* SELECT a, a + b AS a_b FROM __THIS__
* SELECT a, SQRT(b) AS b_sqrt FROM __THIS__ where a > 5
* SELECT a, b, SUM(c) AS c_sum FROM __THIS__ GROUP BY a, b
* }}}
*/
@Since("1.6.0")
class SQLTransformer @Since("1.6.0") (@Since("1.6.0") override val uid: String) extends Transformer
with DefaultParamsWritable {
@Since("1.6.0")
def this() = this(Identifiable.randomUID("sql"))
/**
* SQL语句参数。语句以字符串形式提供。
*
* @group param
*/
@Since("1.6.0")
final val statement: Param[String] = new Param[String](this, "statement", "SQL语句")
/** @group setParam */
@Since("1.6.0")
def setStatement(value: String): this.type = set(statement, value)
/** @group getParam */
@Since("1.6.0")
def getStatement: String = $(statement)
private val tableIdentifier: String = "__THIS__"
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val tableName = Identifiable.randomUID(uid)
dataset.createOrReplaceTempView(tableName)
val realStatement = $(statement).replace(tableIdentifier, tableName)
val result = dataset.sparkSession.sql(realStatement)
// 调用SessionCatalog.dropTempView以避免取消持久化可能缓存的数据集。
dataset.sparkSession.sessionState.catalog.dropTempView(tableName)
result
}
@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
val spark = SparkSession.builder().getOrCreate()
val dummyRDD = spark.sparkContext.parallelize(Seq(Row.empty))
val dummyDF = spark.createDataFrame(dummyRDD, schema)
val tableName = Identifiable.randomUID(uid)
val realStatement = $(statement).replace(tableIdentifier, tableName)
dummyDF.createOrReplaceTempView(tableName)
val outputSchema = spark.sql(realStatement).schema
spark.catalog.dropTempView(tableName)
outputSchema
}
@Since("1.6.0")
override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra)
}
@Since("1.6.0")
object SQLTransformer extends DefaultParamsReadable[SQLTransformer] {
@Since("1.6.0")
override def load(path: String): SQLTransformer = super.load(path)
}