需求
将DataFrame中的StructType类型字段下的所有内容转换为Json字符串。
spark版本: 1.6.1
思路
- DataFrame有toJSON方法,可将每个Row都转为一个Json字符串,并返回RDD[String]
- DataFrame.write.json方法,可将数据写为Json格式文件
跟踪上述两处代码,发现最终都会调用Spark源码中的org.apache.spark.sql.execution.datasources.json.JacksonGenerator类,使用Jackson,根据传入的StructType、JsonGenerator和InternalRow,生成Json字符串。
开发
我们的函数只需传入一个参数,就是需要转换的列,因此需要实现org.apache.spark.sql.catalyst.expressions包下的UnaryExpression。
后续对功能进行了扩展:输入不是StructType类型时,会先将其包装成只含一个名为"col"的字段的StructType,生成Json后再去掉外层包装,因此非StructType类型的输入也可以转换。
package org.apache.spark.sql.catalyst.expressions
import java.io.CharArrayWriter
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenContext
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratedExpressionCode
import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.Metadata
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import com.fasterxml.jackson.core.JsonFactory
import org.apache.spark.unsafe.types.UTF8String
/**
 * Catalyst expression that renders the value of its child as a JSON string.
 *
 * A `StructType` input is serialized as a JSON object by Spark's `JacksonGenerator`.
 * Any other input is wrapped into a one-field struct named "col", serialized the same
 * way, and the wrapper (`{"col":` ... `}`) is then stripped so that only the JSON
 * encoding of the value itself is returned.
 *
 * @param child expression whose value is converted to JSON
 * @author yizhu.sun 2016-08-30
 */
case class Json(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

  override def dataType: DataType = StringType

  // Accept the child's own type unchanged, so no implicit cast is ever inserted.
  override def inputTypes: Seq[DataType] = Seq(child.dataType)

  /**
   * Schema handed to `JacksonGenerator`: the child's struct type itself, or a
   * single-field wrapper struct ("col") around a non-struct child.
   *
   * Lazy so that merely constructing the expression does not dereference
   * `child.dataType`.
   */
  lazy val inputStructType: StructType = child.dataType match {
    case st: StructType => st
    case _ => StructType(Seq(StructField("col", child.dataType, child.nullable, Metadata.empty)))
  }

  override def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess

  // Modeled after org.apache.spark.sql.DataFrame.toJSON
  // Modeled after org.apache.spark.sql.execution.datasources.json.JsonOutputWriter.writeInternal
  /**
   * Interpreted evaluation path.
   *
   * @param data non-null value produced by `child`
   * @return the JSON text for `data`
   */
  protected override def nullSafeEval(data: Any): UTF8String = {
    val writer = new CharArrayWriter
    val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
    try {
      val internalRow = child.dataType match {
        case _: StructType => data.asInstanceOf[InternalRow]
        case _ => InternalRow(data)
      }
      JacksonGenerator(inputStructType, gen)(internalRow)
      gen.flush()
    } finally {
      // Always release the generator, even when serialization fails.
      gen.close()
    }
    val json = writer.toString
    UTF8String.fromString(
      child.dataType match {
        case _: StructType => json
        // Strip the {"col": ... } wrapper: keep what lies between the first ':' and the last '}'.
        case _ => json.substring(json.indexOf(":") + 1, json.lastIndexOf("}"))
      })
  }

  /**
   * Code-generated evaluation path. Emits Java that rebuilds the schema from its JSON
   * form (through `Json.getStructType`) and runs the same serialization as `nullSafeEval`.
   * In the generated code any exception raised while serializing marks the result as null.
   */
  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val writer = ctx.freshName("writer")
    val gen = ctx.freshName("gen")
    val st = ctx.freshName("st")
    val json = ctx.freshName("json")
    // The schema JSON is embedded as a Java string literal. Backslashes must be escaped
    // before quotes; otherwise an already escaped quote (\") inside the JSON would turn
    // into \\" and terminate the Java literal early.
    val typeLiteral = inputStructType.json.replace("\\", "\\\\").replace("\"", "\\\"")

    // Java expression for the row passed to JacksonGenerator.
    def getDataExp(data: Any) =
      child.dataType match {
        case _: StructType => s"$data"
        case _ => s"new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(new Object[]{$data})"
      }

    // Java expression that removes the {"col": ... } wrapper for non-struct inputs.
    def formatJson(json: String) =
      child.dataType match {
        case _: StructType => s"$json"
        case _ => s"""$json.substring($json.indexOf(":") + 1, $json.lastIndexOf("}"))"""
      }

    nullSafeCodeGen(ctx, ev, (row) => {
      s"""
        | com.fasterxml.jackson.core.JsonGenerator $gen = null;
        | try {
        |   org.apache.spark.sql.types.StructType $st = ${classOf[Json].getName}.getStructType("$typeLiteral");
        |   java.io.CharArrayWriter $writer = new java.io.CharArrayWriter();
        |   $gen = new com.fasterxml.jackson.core.JsonFactory().createGenerator($writer).setRootValueSeparator(null);
        |   org.apache.spark.sql.execution.datasources.json.JacksonGenerator.apply($st, $gen, ${getDataExp(row)});
        |   $gen.flush();
        |   String $json = $writer.toString();
        |   ${ev.value} = UTF8String.fromString(${formatJson(json)});
        | } catch (Exception e) {
        |   ${ev.isNull} = true;
        | } finally {
        |   if ($gen != null) $gen.close();
        | }
       """.stripMargin
    })
  }
}
/**
 * Companion helpers used by the Java code that `Json.genCode` emits.
 */
object Json {

  /**
   * Cache of parsed schemas, keyed by their JSON representation ([json, type]).
   *
   * Backed by a concurrent map: the generated code calls `getStructType` for every
   * evaluated row, and that code can run on several threads of the same JVM at once,
   * which a plain `mutable.Map` does not tolerate.
   */
  val structTypeCache: collection.mutable.Map[String, StructType] =
    scala.collection.concurrent.TrieMap.empty[String, StructType]

  /**
   * Returns the `StructType` described by `json`, parsing it on first use and serving
   * it from `structTypeCache` afterwards.
   *
   * @param json schema in the format produced by `StructType.json`
   * @return the parsed schema
   */
  def getStructType(json: String): StructType = {
    structTypeCache.getOrElseUpdate(json, {
      // Diagnostic output showing when a schema is actually parsed (i.e. a cache miss).
      println(">>>>> get StructType from json:")
      println(json)
      DataType.fromJson(json).asInstanceOf[StructType]
    })
  }
}
注册
注意,SQLContext.functionRegistry的可见性为protected[sql],因此下面的注册代码需要放在org.apache.spark.sql包(或其子包)下才能访问。
// Build the registry entry that exposes the Json expression under the SQL name "json".
val (name, entry) = FunctionRegistry.expression[Json]("json")
// The entry pairs the function's descriptive info with the builder that creates the expression.
val (info, builder) = entry
// functionRegistry is protected[sql], so this call must be made from code inside that package.
sqlContext.functionRegistry.registerFunction(name, info, builder)
测试
// Schema of the nested struct column: two nullable strings and a nullable integer.
val subSchema = StructType(Seq(
  StructField("a", StringType, nullable = true),
  StructField("b", StringType, nullable = true),
  StructField("c", IntegerType, nullable = true)))

// Top-level schema: a single nullable struct column "x".
val schema = StructType(Seq(StructField("x", subSchema, nullable = true)))

// Two sample rows, each holding one nested row with some null fields.
val rdd = sc.parallelize(Seq(
  Row(Row("12", null, 123)),
  Row(Row(null, "2222", null))))
val df = sqlContext.createDataFrame(rdd, schema)
df.registerTempTable("df")

import sqlContext.sql

// First show the raw columns, then the same columns passed through json(...);
// for each query print the data followed by the resulting schema.
Seq(
  "select x, x.a from df",
  "select json(x), json(x.a) from df"
).foreach { query =>
  sql(query).show()
  sql(query).printSchema()
}
结果
+----------------+----+
|x |a |
+----------------+----+
|[12,null,123] |12 |
|[null,2222,null]|null|
+----------------+----+
root
|-- x: struct (nullable = true)
| |-- a: string (nullable = true)
| |-- b: string (nullable = true)
| |-- c: integer (nullable = true)
|-- a: string (nullable = true)
>>>>> get StructType from json:
{"type":"struct","fields":[{"name":"a","type":"string","nullable":true,"metadata":{}},{"name":"b","type":"string","nullable":true,"metadata":{}},{"name":"c","type":"integer","nullable":true,"metadata":{}}]}
>>>>> get StructType from json:
{"type":"struct","fields":[{"name":"col","type":"string","nullable":true,"metadata":{}}]}
+------------------+----+
|_c0 |_c1 |
+------------------+----+
|{"a":"12","c":123}|"12"|
|{"b":"2222"} |null|
+------------------+----+
root
|-- _c0: string (nullable = true)
|-- _c1: string (nullable = true)
需要注意的点
- 使用SparkSQL自定义函数一般有两种方法,一种是使用开放的api注册简单函数,即调用sqlContext.udf.register方法。另一种就是使用SparkSQL内置函数的注册方法(本例就是使用的这种方法)。前者优势是开发简单,但是实现不了较为复杂的功能,例如本例中需要获取传入的InternalRow的StructType,或者需要实现类似 def fun(arg: Seq[T]): T 这种泛型相关的功能(sqlContext.udf.register的注册方式无法注册返回值为Any的函数)。
- 本例中实现genCode函数时遇到了困难,即需要在生成的Java代码中构建StructType对象。这个最终通过序列化的思路解决,先使用StructType.json方法将StructType对象序列化为String,然后在Java代码中调用DataType.fromJson反序列化为StructType对象。