Spark Pipeline Stage模型自定义(二)

前言

上篇介绍了Spark的Mllib机器学习工具ML扩展中的Pipeline,并就如何自定义Pipeline Stage模型中的Transformer模型进行讨论。本篇我们讨论Pipeline Stage的另一类模型Estimator,并基于Estimator实现异常点检测,扩展Spark的大数据分析能力。

一、 模型介绍

Spark Pipeline每个阶段Stage包含Transformer和Estimator两类抽象,在这两个抽象中,Transformer模型一般做数据转换用,其不需要获取数据知识,也不需要保存中间数据。而Estimator一般做预测、分析用,其要分析数据的含义,一般要经过一个训练的过程来获得数据知识,因此需要保存模型中间数据。

Spark封装了许多常用的机器学习算法以及转换操作,但并没有异常检测相关算法。我们考虑最简单的基于统计量分析的异常检测算法,其先对变量做一个描述性统计,进而查看哪些数据是不合理的。

这里可以利用统计学上的3σ法则,即若数据服从正太分布,数据有0.9973的概率集中在(μ-3σ,μ+3σ)区间内,数据超出这个范围的可能性不到0.3%,我们可以认为其是异常值。

为了我们的异常检测模型能在一定程度上学习数据分布特征,且在某些情况下能够直接用训练好的模型判断异常点,而不用全量数据重新训练,我们采用Estimator方式对模型进行实现。

二、基于Estimator的异常检测模型

Estimator模型需要通过一个训练的过程获取数据中的知识。现在假设我们有一份人群身高数据,需要自己定义一个基于Estimator模型的异常检测模型,实现一个简单的数据异常点检测功能。期望的效果是给数据新增一列“Abnormal”表示该人身高数值是否异常,是的话赋值1,否则赋值0。

定义该异常检测模型,数据通过Estimator类进行fit训练,保存训练数据,得到EstimatorModel,并由EstimatorModel使用3σ法则对数据进行转换操作,得到最终的结果。可见,为了实现基于Estimator的异常检测模型,我们需要自定义Estimator和EstimatorModel两个类。

1. 构造通用参数Trait

我们需要实现的两个类Estimator和EstimatorModel拥有相同的模型参数,因此为了统一参数设置,可以把参数抽出来单独定义一个Trait,之后让需要的类来继承他。该Trait中设置了inputCol和outputCol参数,以及他们的set方法。

trait MyEstimatorParams extends Params {
    final val inputCol = new Param[String](this, "inputCol", "The input column")
    final val outputCol = new Param[String](this, "outputCol", "The output column")
    def setInputCol(value: String) = set(inputCol, value)
    def setOutputCol(value: String) = set(outputCol, value)
}

2. 实现EstimatorModel

EstimatorModel指的是训练好的模型,其有经过训练后的模型参数,且和Transformer类似,需要实现数据的transform转化功能。

2.1 实现EstimatorModel类

这里我们定义MyEstimatorModel类,以及入参uid和datadf。其中datadf表示模型通过训练获得的参数,以DataFrame结构进行存储。此外,MyEstimatorModel还要继承Model以及前面自定义的通用参数Trait,即MyEstimatorParams。

class MyEstimatorModel (override val uid: String,val data_df: DataFrame) extends Model[MyEstimatorModel] with MyEstimatorParams {
    override def copy(extra: ParamMap): MyEstimatorModel = {
         val copied =new MyEstimatorModel(uid, data _df)
         copyValues(copied, extra).setParent(parent)
    }
}
2.2 重写transformSchema方法

TransformSchema方法确定Transformer操作中,所操作的DateFrame对象schema的变化。因为我们要会增加一列“Abnormal”表示是否是异常,所以这里我们首先要判断待计算列,即inputCol的类型是否是Double,不是的话报错。其次是要返回变化后的表结构,即增加了一列“Abnormal”的StructType。

override def transformSchema(schema: StructType): StructType = {
    val idx = schema.fieldIndex($(inputCol))
    val field = schema.fields(idx)
    if (field.dataType != DoubleType) {
        throw new Exception(s"Input type ${field.dataType} did not match input type DoubleType")
    }
    schema.add(StructField($(outputCol), DoubleType, false))
}
2.3 实现transform方法

transform方法实现Transformer对DataFrame所做的具体操作。该方法先从data_df得到需要的参数avg和std,表示待测数据的均值和标准差。之后Transform方法利用3σ法则,判断待测值与平均值的偏差是否超过3倍标准差,是的话则判断其为异常点。

override def transform(dataset: Dataset[_]): DataFrame = {
    val avg = data_df.collect()(0).apply(0).toString().toDouble
    val std = data_df.collect()(0).apply(1).toString().toDouble
    val func = udf { label: Double =>{
        if (abs(label - avg) > 3 * std) {
            1
        } else {
            0
        }
    }}
    dataset.select(col("*"), func(dataset($(inputCol))).as($(outputCol)))
}

3. 实现Estimator对象

Estimator类需要实现模型的训练功能,得到模型参数后返回训练好的模型EstimatorModel。

3.1 构造基本的Estimator类

这里我们自定义MyEstimator类,并继承Estimator以及MyEstimatorParams

class MyEstimator(override val uid: String) extends Estimator[MyEstimatorModel] with MyEstimatorParams {
    def this() = this(Identifiable.randomUID("MyEstimatorParams"))
    override def copy(extra: ParamMap): MyEstimator = {
        defaultCopy(extra)
    }
}
3.2 重写transformSchema方法

Estimator的TransformSchema方法与EstimatorModel的方法类似。

override def transformSchema(schema: StructType): StructType = {
    val idx = schema.fieldIndex($(inputCol))
    val field = schema.fields(idx)
    if (field.dataType != DoubleType) {
        throw new Exception(s"Input type ${field.dataType} did not match input type DoubleType")
    }
    schema.add(StructField($(outputCol), DoubleType, false))
}
3.3 重写fit方法

Estimator的fit方法描述了模型训练的过程。在本文的异常检测例子中,需要获取待测数据的均值与标准差,并返回训练好的模型EstimatorModel。

override def fit(dataset: Dataset[_]): MyEstimatorModel = {
    import dataset.sparkSession.implicits._
    val data_df = dataset.select(avg($(inputCol)).as("avg"), sqrt(var_pop($(inputCol))).as("std")) 
    val c = new MyEstimatorModel(uid, data_df)
    c.setInputCol($(inputCol)).setOutputCol($(outputCol))
}

4. 实现可读写

4.1 重现DefaultParams读写对象

自定义的Estimator模型实现可读写,需要使用到DefaultParamsWriter和DefaultParamsReader对象的方法。然而在Spark源代码中,DefaultParamsWriter和MyDefaultParamsReader的方法是私有的,自建模型无法引用。

为了直接使用DefaultParamsWriter.saveMetadata、DefaultParamsReader. loadMetadata等方法,我们需要在源码中找出相关内容并重新实现。对于Spark2.2,相关内容位于源码spark.ml.util.ReadWrite.scala中。我们将其提取出来,并实现两个Object,分别是MyDefaultParamsWriter和MyDefaultParamsReader,内容大致如下。

object MyDefaultParamsWriter {
    def saveMetadata…
    def getMetadataToSave…
}
object MyDefaultParamsReader {
    case class Metadata(
      className: String,
      …) {
         def getParamValue…
     }
    def loadMetadata…
    def parseMetadata…
    def getAndSetParams…
}
4.2 实现EstimatorModel伴生对象

EstimatorModel伴生对象Object需继承MLReadable,并实现内部类MyEstimatorReader。该内部类继承MLReader并重写load方法。load方法确定了Pipeline如何从存储中读取MyEstimatorModel,即获取uid以及模型参数后新建MyEstimatorModel对象,并通过我们前面重现的MyDefaultParamsReader向对象写入param参数。

object MyEstimatorModel extends MLReadable[MyEstimatorModel] {
    private class MyEstimatorReader extends MLReader[MyEstimatorModel] {
        private val className = classOf[MyEstimatorModel].getName
        override def load(path: String): MyEstimatorModel = {
            val metadata = MyDefaultParamsReader.loadMetadata(path, sc, className)
            val dataPath = new Path(path, "data").toString
            val data_df = sqlContext.read.parquet(dataPath)
            val model = new MyEstimatorModel(metadata.uid, sum_df)
            MyDefaultParamsReader.getAndSetParams(model, metadata)
            model
        }
    }
    override def read: MLReader[MyEstimatorModel] = new MyEstimatorReader
    override def load(path: String): MyEstimatorModel = super.load(path)
}
4.3 更新EstimatorModel类

更新EstimatorModel类分三步:1)继承MLWritable;2)实现内部类MyEstimatorWriter;3)重写write方法。其中MyEstimatorWriter重写的saveImpl方法确定将模型参数进行存储的过程。

class MyEstimatorModel(override val uid: String, val data_df: DataFrame) extends Model[MyEstimatorModel] with MyEstimatorParams with MLWritable {
    override def copy…
    override def transformSchema…
    override def transform…
    private[MyEstimatorModel] class MyEstimatorWriter(instance: MyEstimatorModel) extends MLWriter {
        override protected def saveImpl(path: String): Unit = {
            MyDefaultParamsWriter.saveMetadata(instance, path, sc)
            val dataPath = new Path(path, "data").toString
            instance.data_df.repartition(1).write.parquet(dataPath)
        }
    }
    override def write: MLWriter = new MyEstimatorWriter(this)
}
4.4 Estimator类更新

这里首先让Estimator类继承DefaultParamsReadable,其次是实现Estimator伴生对象,整体过程与Transformer类似。

class MyEstimator(override val uid: String) extends Estimator[MyEstimatorModel] with MyEstimatorParams with DefaultParamsWritable {…}
object MyEstimator extends DefaultParamsReadable[MyEstimator] {
  override def load(path: String): MyEstimator = super.load(path)
}

三、测试

通过前面的步骤,自定义的Estimator异常检测模型的基本功能就已经实现了,下面我们写个程序测试一下:

val dataset = spark.createDataFrame(Seq(
("mike", 166.0),("tom", 175.0),("wade", 163.0),("bad", 660.0),
("james", 160.0),("black", 166.0),("angel", 162.0),("Emma", 177.0),
("weekn", 174.0),("kelly", 166.0),("grey", 140.0))).toDF("name", "height")
val est=new MyEstimator
est.setInputCol("height").setOutputCol("Abnormal")
val pipeline = new Pipeline().setStages(Array(est))
val model=pipeline.fit(dataset)
model.write.overwrite().save("model")
val r=model.transform(dataset)
r.show

结果如下:

0176d0bb0673805505dc3b66b6de771b757.jpg

这样,自定义Estimator异常检测模型的工作就全部完成了。

转载于:https://my.oschina.net/weekn/blog/1975845

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值