前言
上篇介绍了Spark的Mllib机器学习工具ML扩展中的Pipeline,并就如何自定义Pipeline Stage模型中的Transformer模型进行讨论。本篇我们讨论Pipeline Stage的另一类模型Estimator,并基于Estimator实现异常点检测,扩展Spark的大数据分析能力。
一、 模型介绍
Spark Pipeline每个阶段Stage包含Transformer和Estimator两类抽象,在这两个抽象中,Transformer模型一般做数据转换用,其不需要获取数据知识,也不需要保存中间数据。而Estimator一般做预测、分析用,其要分析数据的含义,一般要经过一个训练的过程来获得数据知识,因此需要保存模型中间数据。
Spark封装了许多常用的机器学习算法以及转换操作,但并没有异常检测相关算法。我们考虑最简单的基于统计量分析的异常检测算法,其先对变量做一个描述性统计,进而查看哪些数据是不合理的。
这里可以利用统计学上的3σ法则,即若数据服从正太分布,数据有0.9973的概率集中在(μ-3σ,μ+3σ)区间内,数据超出这个范围的可能性不到0.3%,我们可以认为其是异常值。
为了我们的异常检测模型能在一定程度上学习数据分布特征,且在某些情况下能够直接用训练好的模型判断异常点,而不用全量数据重新训练,我们采用Estimator方式对模型进行实现。
二、基于Estimator的异常检测模型
Estimator模型需要通过一个训练的过程获取数据中的知识。现在假设我们有一份人群身高数据,需要自己定义一个基于Estimator模型的异常检测模型,实现一个简单的数据异常点检测功能。期望的效果是给数据新增一列“Abnormal”表示该人身高数值是否异常,是的话赋值1,否则赋值0。
定义该异常检测模型,数据通过Estimator类进行fit训练,保存训练数据,得到EstimatorModel,并由EstimatorModel使用3σ法则对数据进行转换操作,得到最终的结果。可见,为了实现基于Estimator的异常检测模型,我们需要自定义Estimator和EstimatorModel两个类。
1. 构造通用参数Trait
我们需要实现的两个类Estimator和EstimatorModel拥有相同的模型参数,因此为了统一参数设置,可以把参数抽出来单独定义一个Trait,之后让需要的类来继承他。该Trait中设置了inputCol和outputCol参数,以及他们的set方法。
trait MyEstimatorParams extends Params {
final val inputCol = new Param[String](this, "inputCol", "The input column")
final val outputCol = new Param[String](this, "outputCol", "The output column")
def setInputCol(value: String) = set(inputCol, value)
def setOutputCol(value: String) = set(outputCol, value)
}
2. 实现EstimatorModel
EstimatorModel指的是训练好的模型,其有经过训练后的模型参数,且和Transformer类似,需要实现数据的transform转化功能。
2.1 实现EstimatorModel类
这里我们定义MyEstimatorModel类,以及入参uid和datadf。其中datadf表示模型通过训练获得的参数,以DataFrame结构进行存储。此外,MyEstimatorModel还要继承Model以及前面自定义的通用参数Trait,即MyEstimatorParams。
class MyEstimatorModel (override val uid: String,val data_df: DataFrame) extends Model[MyEstimatorModel] with MyEstimatorParams {
override def copy(extra: ParamMap): MyEstimatorModel = {
val copied =new MyEstimatorModel(uid, data _df)
copyValues(copied, extra).setParent(parent)
}
}
2.2 重写transformSchema方法
TransformSchema方法确定Transformer操作中,所操作的DateFrame对象schema的变化。因为我们要会增加一列“Abnormal”表示是否是异常,所以这里我们首先要判断待计算列,即inputCol的类型是否是Double,不是的话报错。其次是要返回变化后的表结构,即增加了一列“Abnormal”的StructType。
override def transformSchema(schema: StructType): StructType = {
val idx = schema.fieldIndex($(inputCol))
val field = schema.fields(idx)
if (field.dataType != DoubleType) {
throw new Exception(s"Input type ${field.dataType} did not match input type DoubleType")
}
schema.add(StructField($(outputCol), DoubleType, false))
}
2.3 实现transform方法
transform方法实现Transformer对DataFrame所做的具体操作。该方法先从data_df得到需要的参数avg和std,表示待测数据的均值和标准差。之后Transform方法利用3σ法则,判断待测值与平均值的偏差是否超过3倍标准差,是的话则判断其为异常点。
override def transform(dataset: Dataset[_]): DataFrame = {
val avg = data_df.collect()(0).apply(0).toString().toDouble
val std = data_df.collect()(0).apply(1).toString().toDouble
val func = udf { label: Double =>{
if (abs(label - avg) > 3 * std) {
1
} else {
0
}
}}
dataset.select(col("*"), func(dataset($(inputCol))).as($(outputCol)))
}
3. 实现Estimator对象
Estimator类需要实现模型的训练功能,得到模型参数后返回训练好的模型EstimatorModel。
3.1 构造基本的Estimator类
这里我们自定义MyEstimator类,并继承Estimator以及MyEstimatorParams
class MyEstimator(override val uid: String) extends Estimator[MyEstimatorModel] with MyEstimatorParams {
def this() = this(Identifiable.randomUID("MyEstimatorParams"))
override def copy(extra: ParamMap): MyEstimator = {
defaultCopy(extra)
}
}
3.2 重写transformSchema方法
Estimator的TransformSchema方法与EstimatorModel的方法类似。
override def transformSchema(schema: StructType): StructType = {
val idx = schema.fieldIndex($(inputCol))
val field = schema.fields(idx)
if (field.dataType != DoubleType) {
throw new Exception(s"Input type ${field.dataType} did not match input type DoubleType")
}
schema.add(StructField($(outputCol), DoubleType, false))
}
3.3 重写fit方法
Estimator的fit方法描述了模型训练的过程。在本文的异常检测例子中,需要获取待测数据的均值与标准差,并返回训练好的模型EstimatorModel。
override def fit(dataset: Dataset[_]): MyEstimatorModel = {
import dataset.sparkSession.implicits._
val data_df = dataset.select(avg($(inputCol)).as("avg"), sqrt(var_pop($(inputCol))).as("std"))
val c = new MyEstimatorModel(uid, data_df)
c.setInputCol($(inputCol)).setOutputCol($(outputCol))
}
4. 实现可读写
4.1 重现DefaultParams读写对象
自定义的Estimator模型实现可读写,需要使用到DefaultParamsWriter和DefaultParamsReader对象的方法。然而在Spark源代码中,DefaultParamsWriter和MyDefaultParamsReader的方法是私有的,自建模型无法引用。
为了直接使用DefaultParamsWriter.saveMetadata、DefaultParamsReader. loadMetadata等方法,我们需要在源码中找出相关内容并重新实现。对于Spark2.2,相关内容位于源码spark.ml.util.ReadWrite.scala中。我们将其提取出来,并实现两个Object,分别是MyDefaultParamsWriter和MyDefaultParamsReader,内容大致如下。
object MyDefaultParamsWriter {
def saveMetadata…
def getMetadataToSave…
}
object MyDefaultParamsReader {
case class Metadata(
className: String,
…) {
def getParamValue…
}
def loadMetadata…
def parseMetadata…
def getAndSetParams…
}
4.2 实现EstimatorModel伴生对象
EstimatorModel伴生对象Object需继承MLReadable,并实现内部类MyEstimatorReader。该内部类继承MLReader并重写load方法。load方法确定了Pipeline如何从存储中读取MyEstimatorModel,即获取uid以及模型参数后新建MyEstimatorModel对象,并通过我们前面重现的MyDefaultParamsReader向对象写入param参数。
object MyEstimatorModel extends MLReadable[MyEstimatorModel] {
private class MyEstimatorReader extends MLReader[MyEstimatorModel] {
private val className = classOf[MyEstimatorModel].getName
override def load(path: String): MyEstimatorModel = {
val metadata = MyDefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val data_df = sqlContext.read.parquet(dataPath)
val model = new MyEstimatorModel(metadata.uid, sum_df)
MyDefaultParamsReader.getAndSetParams(model, metadata)
model
}
}
override def read: MLReader[MyEstimatorModel] = new MyEstimatorReader
override def load(path: String): MyEstimatorModel = super.load(path)
}
4.3 更新EstimatorModel类
更新EstimatorModel类分三步:1)继承MLWritable;2)实现内部类MyEstimatorWriter;3)重写write方法。其中MyEstimatorWriter重写的saveImpl方法确定将模型参数进行存储的过程。
class MyEstimatorModel(override val uid: String, val data_df: DataFrame) extends Model[MyEstimatorModel] with MyEstimatorParams with MLWritable {
override def copy…
override def transformSchema…
override def transform…
private[MyEstimatorModel] class MyEstimatorWriter(instance: MyEstimatorModel) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
MyDefaultParamsWriter.saveMetadata(instance, path, sc)
val dataPath = new Path(path, "data").toString
instance.data_df.repartition(1).write.parquet(dataPath)
}
}
override def write: MLWriter = new MyEstimatorWriter(this)
}
4.4 Estimator类更新
这里首先让Estimator类继承DefaultParamsReadable,其次是实现Estimator伴生对象,整体过程与Transformer类似。
class MyEstimator(override val uid: String) extends Estimator[MyEstimatorModel] with MyEstimatorParams with DefaultParamsWritable {…}
object MyEstimator extends DefaultParamsReadable[MyEstimator] {
override def load(path: String): MyEstimator = super.load(path)
}
三、测试
通过前面的步骤,自定义的Estimator异常检测模型的基本功能就已经实现了,下面我们写个程序测试一下:
val dataset = spark.createDataFrame(Seq(
("mike", 166.0),("tom", 175.0),("wade", 163.0),("bad", 660.0),
("james", 160.0),("black", 166.0),("angel", 162.0),("Emma", 177.0),
("weekn", 174.0),("kelly", 166.0),("grey", 140.0))).toDF("name", "height")
val est=new MyEstimator
est.setInputCol("height").setOutputCol("Abnormal")
val pipeline = new Pipeline().setStages(Array(est))
val model=pipeline.fit(dataset)
model.write.overwrite().save("model")
val r=model.transform(dataset)
r.show
结果如下:
这样,自定义Estimator异常检测模型的工作就全部完成了。