转自:http://blog.csdn.net/yizishou/article/details/52398665
需求
将DataFrame中的StructType类型字段下的所有内容转换为Json字符串。
Spark版本: 1.6.1
思路
- DataFrame有toJSON方法,可将每个Row都转为一个Json字符串,并返回RDD[String]
- DataFrame.write.json方法,可将数据写为Json格式文件
跟踪上述两处代码,发现最终都会调用Spark源码中的org.apache.spark.sql.execution.datasources.json.JacksonGenerator类,使用Jackson,根据传入的StructType、JsonGenerator和InternalRow,生成Json字符串。
开发
我们的函数只需传入一个参数,就是需要转换的列,因此需要实现org.apache.spark.sql.catalyst.expressions包下的UnaryExpression。
后续对功能进行了扩展,非StructType类型的输入也可以转换:先将其包装成只含一个字段(col)的StructType,生成Json后再去掉外层的{"col":...},只保留值本身。
package org.apache.spark.sql.catalyst.expressions

import java.io.CharArrayWriter

import scala.collection.concurrent.TrieMap

import com.fasterxml.jackson.core.JsonFactory

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenContext
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratedExpressionCode
import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.Metadata
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String
-
-
/**
 * Spark SQL expression that renders its child column as a JSON string.
 *
 * StructType children are written as a JSON object (e.g. `{"a":"12","c":123}`). Any other
 * child type is first wrapped in a one-field struct named "col"; after rendering, the
 * `{"col":` ... `}` wrapper is stripped so only the JSON value itself is returned
 * (e.g. `"12"` for a string). A null input produces a null result.
 *
 * @param child the column to convert
 * @author yizhu.sun 2016-08-30
 */
case class Json(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

  override def dataType: DataType = StringType
  override def inputTypes: Seq[DataType] = Seq(child.dataType)

  /**
   * Schema handed to JacksonGenerator: the child's own struct type, or a one-field "col"
   * wrapper for non-struct children. Lazy so it is derived from `child.dataType` only when
   * first needed rather than eagerly at construction time.
   */
  lazy val inputStructType: StructType = child.dataType match {
    case st: StructType => st
    case other => StructType(Seq(StructField("col", other, child.nullable, Metadata.empty)))
  }

  // Any input type is accepted: non-struct inputs are wrapped (see inputStructType).
  override def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess

  // Interpreted path. Modeled on org.apache.spark.sql.DataFrame.toJSON and
  // org.apache.spark.sql.execution.datasources.json.JsonOutputWriter.writeInternal.
  protected override def nullSafeEval(data: Any): UTF8String = {
    val row = child.dataType match {
      case _: StructType => data.asInstanceOf[InternalRow]
      case _ => InternalRow(data)
    }
    val writer = new CharArrayWriter
    val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
    try {
      JacksonGenerator(inputStructType, gen)(row)
      gen.flush()
    } finally {
      // Close even when generation fails; the generated-code path below does the same.
      gen.close()
    }
    UTF8String.fromString(unwrap(writer.toString))
  }

  /**
   * Generated-code path. The schema object cannot be referenced directly from the generated
   * Java, so it is embedded as its JSON text and rebuilt (and cached) by Json.getStructType.
   * Jackson's checked IOExceptions are rethrown unchecked so failures surface exactly as in
   * nullSafeEval instead of silently producing NULL.
   */
  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val writer = ctx.freshName("writer")
    val gen = ctx.freshName("gen")
    val st = ctx.freshName("st")
    val json = ctx.freshName("json")
    val ioe = ctx.freshName("ioe")
    // Fully escaped so field names containing quotes or backslashes survive the round trip.
    val schemaLiteral = javaStringLiteral(inputStructType.json)

    // Java expression turning the child's value into the row JacksonGenerator expects.
    def toRowCode(value: String): String = child.dataType match {
      case _: StructType => value
      case _ => s"new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(new Object[]{$value})"
    }

    // Java counterpart of `unwrap`.
    def unwrapCode(jsonVar: String): String = child.dataType match {
      case _: StructType => jsonVar
      case _ => s"""$jsonVar.substring($jsonVar.indexOf(":") + 1, $jsonVar.lastIndexOf("}"))"""
    }

    nullSafeCodeGen(ctx, ev, value => {
      s"""
        |java.io.CharArrayWriter $writer = new java.io.CharArrayWriter();
        |com.fasterxml.jackson.core.JsonGenerator $gen = null;
        |try {
        |  org.apache.spark.sql.types.StructType $st =
        |    ${classOf[Json].getName}.getStructType($schemaLiteral);
        |  $gen = new com.fasterxml.jackson.core.JsonFactory()
        |    .createGenerator($writer).setRootValueSeparator(null);
        |  org.apache.spark.sql.execution.datasources.json.JacksonGenerator.apply(
        |    $st, $gen, ${toRowCode(value)});
        |  $gen.flush();
        |} catch (java.io.IOException $ioe) {
        |  throw new RuntimeException($ioe);
        |} finally {
        |  if ($gen != null) {
        |    try {
        |      $gen.close();
        |    } catch (java.io.IOException $ioe) {
        |      throw new RuntimeException($ioe);
        |    }
        |  }
        |}
        |String $json = $writer.toString();
        |${ev.value} = UTF8String.fromString(${unwrapCode(json)});
      """.stripMargin
    })
  }

  /** Removes the `{"col":` ... `}` wrapper added for non-struct children; structs pass through. */
  private def unwrap(json: String): String = child.dataType match {
    case _: StructType => json
    case _ => json.substring(json.indexOf(":") + 1, json.lastIndexOf("}"))
  }

  /**
   * Quotes `s` as a Java string literal. Backslashes are escaped first so that escape
   * sequences already present in `s` (e.g. `\\` or `\u0001` from JSON) reach the generated
   * code unchanged instead of being reinterpreted by the Java compiler.
   */
  private def javaStringLiteral(s: String): String = {
    val escaped = s
      .replace("\\", "\\\\")
      .replace("\"", "\\\"")
      .replace("\n", "\\n")
      .replace("\r", "\\r")
    "\"" + escaped + "\""
  }
}
-
-
object Json {

  /**
   * Deserialized schemas keyed by their JSON text ([json, type]); entries are never evicted.
   *
   * getStructType is called from generated code, which may run on several task threads in
   * the same JVM at once. A plain mutable.Map is not safe under concurrent updates, so this
   * is a concurrent TrieMap (still usable wherever a mutable.Map is expected).
   */
  val structTypeCache: collection.concurrent.Map[String, StructType] =
    TrieMap.empty[String, StructType]

  /**
   * Returns the StructType described by `json`, parsing it on first use and caching it.
   *
   * @param json a schema serialized with `StructType.json`
   * @return the cached or freshly parsed schema
   */
  def getStructType(json: String): StructType =
    structTypeCache.getOrElseUpdate(json, {
      // Debug trace (visible in the article's sample output), emitted on a cache miss.
      // Racing callers may occasionally parse, and print, the same schema twice.
      println(">>>>> get StructType from json:")
      println(json)
      DataType.fromJson(json).asInstanceOf[StructType]
    })
}
注册
注意,SQLContext.functionRegistry的可见性为protected[sql],因此注册代码需要放在org.apache.spark.sql包(或其子包)下才能访问它。
// Build the (name, (info, builder)) entry for Json and register it under the SQL name "json".
// SQLContext.functionRegistry is protected[sql], so this has to run from code that can see it,
// e.g. code placed in the org.apache.spark.sql package.
FunctionRegistry.expression[Json]("json") match {
  case (name, (info, builder)) =>
    sqlContext.functionRegistry.registerFunction(name, info, builder)
}
测试
// Sample data: one nullable struct column x holding string fields a, b and an int field c.
val subSchema = StructType(Seq(
  StructField("a", StringType, nullable = true),
  StructField("b", StringType, nullable = true),
  StructField("c", IntegerType, nullable = true)))

val schema = StructType(Seq(StructField("x", subSchema, nullable = true)))

val df = sqlContext.createDataFrame(
  sc.makeRDD(Seq(
    Row(Row("12", null, 123)),
    Row(Row(null, "2222", null)))),
  schema)

df.registerTempTable("df")

import sqlContext.sql

// For each query print its rows, then its schema: first the raw struct, then the json() output.
Seq("select x, x.a from df", "select json(x), json(x.a) from df").foreach { query =>
  sql(query).show()
  sql(query).printSchema()
}
结果
- +----------------+----+
- |x |a |
- +----------------+----+
- |[12,null,123] |12 |
- |[null,2222,null]|null|
- +----------------+----+
-
- root
- |-- x: struct (nullable = true)
- | |-- a: string (nullable = true)
- | |-- b: string (nullable = true)
- | |-- c: integer (nullable = true)
- |-- a: string (nullable = true)
-
- >>>>> get StructType from json:
- {"type":"struct","fields":[{"name":"a","type":"string","nullable":true,"metadata":{}},{"name":"b","type":"string","nullable":true,"metadata":{}},{"name":"c","type":"integer","nullable":true,"metadata":{}}]}
- >>>>> get StructType from json:
- {"type":"struct","fields":[{"name":"col","type":"string","nullable":true,"metadata":{}}]}
-
- +------------------+----+
- |_c0 |_c1 |
- +------------------+----+
- |{"a":"12","c":123}|"12"|
- |{"b":"2222"} |null|
- +------------------+----+
-
- root
- |-- _c0: string (nullable = true)
- |-- _c1: string (nullable = true)
需要注意的点
- 使用SparkSQL自定义函数一般有两种方法,一种是使用开放的api注册简单函数,即调用sqlContext.udf.register方法。另一种就是使用SparkSQL内置函数的注册方法(本例就是使用的这种方法)。前者优势是开发简单,但是实现不了较为复杂的功能,例如本例中需要获取传入的InternalRow的StructType,或者需要实现类似 def fun(arg: Seq[T]): T 这种泛型相关的功能(sqlContext.udf.register的注册方式无法注册返回值为Any的函数)。
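- 本例中实现genCode函数时遇到了困难:需要在生成的Java代码中构建StructType对象。最终通过序列化的思路解决了这个问题:先使用StructType.json方法将StructType对象序列化为String,然后在Java代码中调用DataType.fromJson将其反序列化为StructType对象。Json.getStructType会按Json字符串缓存反序列化结果,因此每种schema只需解析一次(即上面结果中每种schema只打印一次“>>>>> get StructType from json”)。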