Serialization – Is there a way to serialize a custom Transformer in a Spark ML Pipeline...

I am using an ML Pipeline with various custom UDF Transformers. What I am looking for is a way to serialize/deserialize this pipeline.

I tried serializing the PipelineModel using

ObjectOutputStream.write()

but whenever I try to deserialize the pipeline I get:

java.lang.ClassNotFoundException: org.sparkexample.DateTransformer

where DateTransformer is my custom Transformer. Is there any method/interface I can implement to get proper serialization?

I found that there is the

MLWritable

interface that my class could implement (DateTransformer extends Transformer), but I could not find a useful example of it.
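For reference, the plain Java serialization I used looks roughly like this (a minimal sketch; the helper names are mine, not a Spark API):

import java.io.{FileInputStream, FileOutputStream, ObjectInputStream, ObjectOutputStream}
import org.apache.spark.ml.PipelineModel

// Writing works, but reading fails with ClassNotFoundException unless the custom
// Transformer class (e.g. org.sparkexample.DateTransformer) is visible to the
// classloader doing the deserialization.
def writeModel(model: PipelineModel, file: String): Unit = {
  val out = new ObjectOutputStream(new FileOutputStream(file))
  try out.writeObject(model) finally out.close()
}

def readModel(file: String): PipelineModel = {
  val in = new ObjectInputStream(new FileInputStream(file))
  try in.readObject().asInstanceOf[PipelineModel] finally in.close()
}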

Accepted answer: The short answer is that you can't, at least not easily.

The developers went out of their way to make adding new Transformers/Estimators as hard as possible. Basically everything in org.apache.spark.ml.util.ReadWrite is private (except MLWritable and MLReadable), so there is no way to reuse any of the utility methods/classes/objects in there. There is also (as I'm sure you have already discovered) absolutely no documentation on how to do this, but hey, good code documents itself, right?!

Digging through the code in org.apache.spark.ml.util.ReadWrite and org.apache.spark.ml.feature.HashingTF, it seems that you need to override MLWritable.write and MLReadable.read. DefaultParamsWriter and DefaultParamsReader, which appear to contain the actual save/load implementation, save and load a bunch of metadata (an example of the resulting file is sketched after this list):

- class
- timestamp
- sparkVersion
- uid
- paramMap
- (optionally, extra metadata)
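For illustration, the metadata written to path + "/metadata" is a single line of JSON roughly of this shape (assuming the CustomTransform further below with both params set; all values here are invented):

{"class":"org.example.CustomTransform","timestamp":1483228800000,"sparkVersion":"1.6.3","uid":"customThing_0123456789ab","paramMap":{"inputCol":"input","outputCol":"output"}}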

So any implementation would at least need to cover those, and a Transformer that does not need to learn a model could probably get away with just that. A model that needs to be fitted also has to save that data in its save/write implementation. For example, that is what LocalLDAModel does (https://github.com/apache/spark/blob/v1.6.3/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala#L523), so the learned model is (it seems) simply saved as a parquet file:

val data = sqlContext.read.parquet(dataPath)
  .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration",
    "gammaShape")
  .head()
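On the write side, a fitted model would save that data next to the metadata. A minimal sketch of what such a saveImpl might look like, for a hypothetical fitted model MyModel with made-up fields intercept and scale, using the saveMetadata helper shown further down:

import org.apache.hadoop.fs.Path
import org.apache.spark.ml.util.MLWriter

// Sketch only: not LocalLDAModel's actual code, just the same pattern.
class MyModelWriter(instance: MyModel) extends MLWriter {
  override protected def saveImpl(path: String): Unit = {
    // Same metadata (class, uid, params, ...) as for a stateless Transformer.
    DefaultParamsWriter.saveMetadata(instance, path, sc)
    // The learned state itself goes into parquet under path + "/data"; the matching
    // MLReader would read it back with sqlContext.read.parquet(dataPath), as in the
    // LocalLDAModel snippet above.
    val dataPath = new Path(path, "data").toString
    sqlContext.createDataFrame(Seq((instance.intercept, instance.scale)))
      .toDF("intercept", "scale")
      .write.parquet(dataPath)
  }
}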

As a test, I copied everything from org.apache.spark.ml.util.ReadWrite that seemed to be needed and tested the following Transformer, which does not do anything useful.

WARNING: this is almost certainly the wrong way to do it and is very likely to break in the future. I sincerely hope I have misunderstood something and someone will correct me on how to actually create a Transformer that can be serialized/deserialized.

This is against Spark 1.6.3 and may already be broken if you are using 2.x.

import org.apache.spark.sql.types._
import org.apache.spark.ml.param._
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkContext
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util.{Identifiable, MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.sql.{SQLContext, DataFrame}
import org.apache.spark.sql.functions.{col, udf}  // needed by transform below
import org.apache.spark.mllib.linalg._
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

object CustomTransform extends DefaultParamsReadable[CustomTransform] {
  /* Companion object for deserialisation */
  override def load(path: String): CustomTransform = super.load(path)
}

class CustomTransform(override val uid: String)
  extends Transformer with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("customThing"))

  val inputCol = new Param[String](this, "inputCol", "input column")
  val outputCol = new Param[String](this, "outputCol", "output column")

  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)
  def getOutputCol(): String = getOrDefault(outputCol)

  override def transform(dataset: DataFrame): DataFrame = {
    val sqlContext = SQLContext.getOrCreate(SparkContext.getOrCreate())
    import sqlContext.implicits._

    val outCol = extractParamMap.getOrElse(outputCol, "output")
    val inCol = extractParamMap.getOrElse(inputCol, "input")
    val transformUDF = udf({ vector: SparseVector =>
      // WHATEVER YOUR TRANSFORMER NEEDS TO DO GOES HERE;
      // return a Vector so the column matches the VectorUDT declared in transformSchema
      Vectors.dense(vector.values.map(_ * 10))
    })
    dataset.withColumn(outCol, transformUDF(col(inCol)))
  }

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  override def transformSchema(schema: StructType): StructType = {
    val outputFields = schema.fields :+
      StructField(extractParamMap.getOrElse(outputCol, "output"), new VectorUDT, nullable = false)
    StructType(outputFields)
  }
}

trait DefaultParamsWritable extends MLWritable { self: Params =>
  override def write: MLWriter = new DefaultParamsWriter(this)
}

trait DefaultParamsReadable[T] extends MLReadable[T] {
  override def read: MLReader[T] = new DefaultParamsReader[T]
}

class DefaultParamsWriter(instance: Params) extends MLWriter {
  override protected def saveImpl(path: String): Unit = {
    DefaultParamsWriter.saveMetadata(instance, path, sc)
  }
}

object DefaultParamsWriter {
  /**
   * Saves metadata + Params to: path + "/metadata"
   *  - class
   *  - timestamp
   *  - sparkVersion
   *  - uid
   *  - paramMap
   *  - (optionally, extra metadata)
   *
   * @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc.
   * @param paramMap If given, this is saved in the "paramMap" field.
   *                 Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
   *                 [[org.apache.spark.ml.param.Param.jsonEncode()]].
   */
  def saveMetadata(
      instance: Params,
      path: String,
      sc: SparkContext,
      extraMetadata: Option[JObject] = None,
      paramMap: Option[JValue] = None): Unit = {
    val uid = instance.uid
    val cls = instance.getClass.getName
    val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
    val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
      p.name -> parse(p.jsonEncode(v))
    }.toList))
    val basicMetadata = ("class" -> cls) ~
      ("timestamp" -> System.currentTimeMillis()) ~
      ("sparkVersion" -> sc.version) ~
      ("uid" -> uid) ~
      ("paramMap" -> jsonParams)
    val metadata = extraMetadata match {
      case Some(jObject) =>
        basicMetadata ~ jObject
      case None =>
        basicMetadata
    }
    val metadataPath = new Path(path, "metadata").toString
    val metadataJson = compact(render(metadata))
    sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
  }
}

class DefaultParamsReader[T] extends MLReader[T] {
  override def load(path: String): T = {
    val metadata = DefaultParamsReader.loadMetadata(path, sc)
    val cls = Class.forName(metadata.className, true,
      Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader))
    val instance =
      cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
    DefaultParamsReader.getAndSetParams(instance, metadata)
    instance.asInstanceOf[T]
  }
}

object DefaultParamsReader {
  /**
   * All info from metadata file.
   *
   * @param params paramMap, as a [[JValue]]
   * @param metadata All metadata, including the other fields
   * @param metadataJson Full metadata file String (for debugging)
   */
  case class Metadata(
      className: String,
      uid: String,
      timestamp: Long,
      sparkVersion: String,
      params: JValue,
      metadata: JValue,
      metadataJson: String)

  /**
   * Load metadata from file.
   *
   * @param expectedClassName If non empty, this is checked against the loaded metadata.
   * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
   */
  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
    val metadataPath = new Path(path, "metadata").toString
    val metadataStr = sc.textFile(metadataPath, 1).first()
    val metadata = parse(metadataStr)

    implicit val format = DefaultFormats
    val className = (metadata \ "class").extract[String]
    val uid = (metadata \ "uid").extract[String]
    val timestamp = (metadata \ "timestamp").extract[Long]
    val sparkVersion = (metadata \ "sparkVersion").extract[String]
    val params = metadata \ "paramMap"
    if (expectedClassName.nonEmpty) {
      require(className == expectedClassName, s"Error loading metadata: Expected class name" +
        s" $expectedClassName but found class name $className")
    }

    Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
  }

  /**
   * Extract Params from metadata, and set them in the instance.
   * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
   */
  def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
    implicit val format = DefaultFormats
    metadata.params match {
      case JObject(pairs) =>
        pairs.foreach { case (paramName, jsonValue) =>
          val param = instance.getParam(paramName)
          val value = param.jsonDecode(compact(render(jsonValue)))
          instance.set(param, value)
        }
      case _ =>
        throw new IllegalArgumentException(
          s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
    }
  }

  /**
   * Load a [[Params]] instance from the given path, and return it.
   * This assumes the instance implements [[MLReadable]].
   */
  def loadParamsInstance[T](path: String, sc: SparkContext): T = {
    val metadata = DefaultParamsReader.loadMetadata(path, sc)
    val cls = Class.forName(metadata.className, true,
      Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader))
    cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
  }
}

With that you can use CustomTransform in a pipeline and save/load the pipeline. I tested it fairly quickly in spark-shell and it seems to work, but it is certainly not pretty.
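For completeness, a minimal sketch of the kind of save/load round trip described above, as run in a Spark 1.6 spark-shell (the data, column names and path are made up):

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.mllib.linalg.Vectors

// A tiny DataFrame with a sparse vector column for the transformer to work on.
val df = sqlContext.createDataFrame(Seq(
  (0, Vectors.sparse(3, Array(0, 2), Array(1.0, 3.0))),
  (1, Vectors.sparse(3, Array(1), Array(2.0)))
)).toDF("id", "features")

val ct = new CustomTransform().setInputCol("features").setOutputCol("scaled")
val pipeline = new Pipeline().setStages(Array(ct))
val model = pipeline.fit(df)

// Saving writes metadata for every stage, including CustomTransform; loading goes
// through the read method on the companion object.
model.save("/tmp/custom-pipeline")
val reloaded = PipelineModel.load("/tmp/custom-pipeline")
reloaded.transform(df).show()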
