Spark特征处理之RFormula源码解析

##RFormula简单介绍
RFormula通过R模型公式来操作列。
支持R操作中的部分操作包括‘~’, ‘.’, ‘:’, ‘+’以及‘-‘。

1、 ~分隔目标和对象

2、 +合并对象,“+0”意味着删除空格

3、-删除一个对象,“-1”表示删除空格

4、 :交互(数值相乘,类别二值化)

5、 . 除了目标列的全部列

假设a和b为两列:

1、 y ~ a + b表示模型y ~ w0 + w1 * a +w2 * b其中w0为截距,w1和w2为相关系数

2、 y ~a + b + a:b – 1表示模型y ~ w1* a + w2 * b + w3 * a * b,其中w1,w2,w3是相关系数

RFormula产生一个向量特征列以及一个double或者字符串标签列。如果用R进行线性回归,则对String类型的输入列进行one-hot编码、对数值型的输入列进行double类型转化。如果类别列是字符串类型,它将通过StringIndexer转换为double类型。如果标签列不存在,则输出中将通过规定的响应变量创造一个标签列。
##代码示例

/**
  * Created by hhy on 2017/12/05.
  */
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.sql.SparkSession
object RFormulaDemo {
  def main(args: Array[String]): Unit = {
    val spark=SparkSession.builder().appName(" ").master("local").getOrCreate()
    val dataset = spark.createDataFrame(Seq(
      (7,10, "US", 18, 1.0),
      (8,22, "CA", 12, 0.0),
      (9,100,"CA", 15, 0.0),
      (10,29,"CA", 15, 1.0),
      (11,88,"CA", 15, 0.0),
      (12,99,"CA", 2, 0.0)
    )).toDF("id", "count","country", "hour", "clicked")
/**
*country列6个不同取值时候占了五个维度  五个不同取值时候占了四个维度
*四个不同取值时候占了三个维度  三个不同取值占了两维度  两个不同取值占
*了一个维度,另外我们还操作了非StringType类型的hour 和 count列 因此*在country列所占维度基础上 再加上两个维度,就是所形成的新列features 
*该列值是一个向量  由上面组成的维度构成
*/
    val formula = new RFormula()
      .setFormula("clicked ~ country + hour+count")
      .setFeaturesCol("features")
      .setLabelCol("label")

    val output = formula.fit(dataset).transform(dataset)
    output.show()
    //output.write.json("spark-warehouse/Rformula")
   // output.select("features", "label").show()

  }
}

其中,如果我们查看输出列features的值时候,是按照系数向量存储的,如果该列维度小于3(或者4 )自己试验 忘记了,,那么正常输出,比如country列只有一个数值 那么我们用0就可以表示了 构成的向量格式如下:
[0,hour列值,count列值]
[0,hour列值,count列值]
[0,hour列值,count列值]
[0,hour列值,count列值]
如果country列不止一个值 在进行onehot编码时候该列肯定不能使用一维就可以表示完所有取值,形成的新列features大于等于4时候就要输出稀疏向量形式了结果如下所示:
这里写图片描述

##源码解析

class RFormula(override val uid: String)
  extends Estimator[RFormulaModel] with RFormulaBase with DefaultParamsWritable {

  /**
  *Identifiable.randomUID("rFormula")的作用是生成一个
  *以rFormula为前缀     然后加上下划线_      然后加上
  *UUID.randomUUID().toString.takeRight(12)  
  *12个随机十六进制字符
  *具体形式:    “rFormula_12个十六进制字符”
  *this指向了当前RFormula 然后为其参数uid赋值
  */
  def this() = this(Identifiable.randomUID("rFormula"))

  /**
   * RFormula参数,String类型的参数  
   */
  val formula: Param[String] = new Param(this, "formula", "R model formula")

  /**
   * 设置R公式为RFormula转换器  使用之前必须先调用这个函数设置参数
   * 例如"y ~ x + z"
   */
  def setFormula(value: String): this.type = set(formula, value)

  /** 得到RFormula的参数使用了  ${变量} 的形式*/
  def getFormula: String = $(formula)

  /** 设置得到的新列的列名字 */
  def setFeaturesCol(value: String): this.type = set(featuresCol, value)

  /** 针对R公式  设置label列的列名*/
  def setLabelCol(value: String): this.type = set(labelCol, value)

  /**
   * 将label列索引化
   * 一般情况我们我们只索引化字符串类型的label列
   * 在分类算法中,我们设置其为true,及时该列是数值类型
   * 我们也可以索引化
   */
  val forceIndexLabel: BooleanParam = new BooleanParam(this, "forceIndexLabel",
    "Force to index label whether it is numeric or string")
  setDefault(forceIndexLabel -> false)

  /**获取forceIndexLabel变量值 */
  def getForceIndexLabel: Boolean = $(forceIndexLabel)

  /** 设置forceIndexLabel变量的值*/
  def setForceIndexLabel(value: Boolean): this.type = set(forceIndexLabel, value)

  /**
   *是否特殊化拟合截距
  */
  private[ml] def hasIntercept: Boolean = {
    require(isDefined(formula), "Formula must be defined first.")
    RFormulaParser.parse($(formula)).hasIntercept
  }

  /**对于RFormula核心部分*/
  override def fit(dataset: Dataset[_]): RFormulaModel = {
    transformSchema(dataset.schema, logging = true)
    require(isDefined(formula), "Formula must be defined first.")
    /**解析给的R公式,返回类型ParsedRFormula
    *RFormulaParser继承了RegexParsers
    */
    val parsedFormula = RFormulaParser.parse($(formula))
    /** 返回类型ResolvedRFormula :将RFormula terms转为列名
    *  该类其中三个参数:
    *    label:String  列名;
    *    terms:Seq[Seq[String]] the simplified terms of the R formula
    *    hasIntercept:Boolean  是否特殊化拟合截距
    */
    val resolvedFormula = parsedFormula.resolve(dataset.schema)
    
    /**定义了一个存放PipelineStage的集合*/
    val encoderStages = ArrayBuffer[PipelineStage]()

    val prefixesToRewrite = mutable.Map[String, String]()  //重写前缀
    val tempColumns = ArrayBuffer[String]()
    def tmpColumn(category: String): String = {
      val col = Identifiable.randomUID(category)    // 返回String类型   字符串形式   category_12个十六进制字符 然后存储到tempColumns中
      tempColumns += col
      col
    }

// First we index each string column referenced by the input terms.将string类型的列索引化
    val indexed: Map[String, String] = resolvedFormula.terms.flatten.distinct.map { term =>
      dataset.schema(term) match {
        case column if column.dataType == StringType =>
          val indexCol = tmpColumn("stridx")  //设置每个列操作后的输出列名  stridx_12个16进制字符 组成的字符串
          encoderStages += new StringIndexer() // encoderStages是一个集合 存放了PipelineStage主要就是对每一个string类型列创建了一个StringIndexer操作
            .setInputCol(term)
            .setOutputCol(indexCol)
            //替换一下前缀标识
          prefixesToRewrite(indexCol + "_") = term + "_"      
          (term, indexCol)
        case _ =>
          (term, term)
      }
    }.toMap

    // Then we handle one-hot encoding and interactions between terms.
    val encodedTerms = resolvedFormula.terms.map {
      case Seq(term) if dataset.schema(term).dataType == StringType =>
        val encodedCol = tmpColumn("onehot")
        encoderStages += new OneHotEncoder()  //对每个string类型的列 继续增加PipelineStage操作 这个才做是onehot的操作  000  001 010 100 可以表示四个不同值不用110 011类似的
          .setInputCol(indexed(term))
          .setOutputCol(encodedCol)
        prefixesToRewrite(encodedCol + "_") = term + "_"
        encodedCol
      case Seq(term) =>
        term
      case terms =>
        val interactionCol = tmpColumn("interaction")
        encoderStages += new Interaction()
          .setInputCols(terms.map(indexed).toArray)
          .setOutputCol(interactionCol)
        prefixesToRewrite(interactionCol + "_") = ""
        interactionCol
    }

    encoderStages += new VectorAssembler(uid)  //继续添加通道  操作是将若干个列向量合并为一列  设置输入列  输出列参数
      .setInputCols(encodedTerms.toArray)
      .setOutputCol($(featuresCol))
    encoderStages += new VectorAttributeRewriter($(featuresCol), prefixesToRewrite.toMap)//通过前缀替换重写向量属性名字这里是将StringIndexer操作输出的strid_   以及onehot操作命名的onehot_前缀替换掉  统一改为了term_  term是R公式的simplified terms
    encoderStages += new ColumnPruner(tempColumns.toSet)  //移除临时列

/**如果数据集中包含了给出的R公式中的label列,并且该列数据类型是String类型,或者我们设置的参数变量forceIndexLabel为true那么将label列使用StringIndexer索引化,此处并未执行,只是将这个转换器放到了 存储Piplinestage的encoderStages集合中了 */
    if ((dataset.schema.fieldNames.contains(resolvedFormula.label) &&
      dataset.schema(resolvedFormula.label).dataType == StringType) || $(forceIndexLabel)) {
      encoderStages += new StringIndexer()
        .setInputCol(resolvedFormula.label)
        .setOutputCol($(labelCol))
    }
/**调用fit执行encoderStages里面的PipelineStage,Pipeline的参数uid等于RFormula的参数uid*/
    val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
    copyValues(new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(this))
  }


  // optimistic schema; does not contain any ML attributes
  override def transformSchema(schema: StructType): StructType = {
    require(!hasLabelCol(schema) || !$(forceIndexLabel),
      "If label column already exists, forceIndexLabel can not be set with true.")
    if (hasLabelCol(schema)) {
      StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true))
    } else {
      StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true) :+
        StructField($(labelCol), DoubleType, true))
    }
  }

  @Since("1.5.0")
  override def copy(extra: ParamMap): RFormula = defaultCopy(extra)

  @Since("2.0.0")
  override def toString: String = s"RFormula(${get(formula).getOrElse("")}) (uid=$uid)"
}

val resolvedFormula = parsedFormula.resolve(dataset.schema)代码中ParsedRFormula类的resolve()函数的源码如下:

private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) {
  /**
   * Resolves formula terms into column names. A schema is necessary for inferring the meaning
   * of the special '.' term. Duplicate terms will be removed during resolution.
   */
  def resolve(schema: StructType): ResolvedRFormula = {
    val dotTerms = expandDot(schema)
    var includedTerms = Seq[Seq[String]]()
    terms.foreach {
      case col: ColumnRef =>
        includedTerms :+= Seq(col.value)
      case ColumnInteraction(cols) =>
        includedTerms ++= expandInteraction(schema, cols)
      case Dot =>
        includedTerms ++= dotTerms.map(Seq(_))
      case Deletion(term: Term) =>
        term match {
          case inner: ColumnRef =>
            includedTerms = includedTerms.filter(_ != Seq(inner.value))
          case ColumnInteraction(cols) =>
            val fromInteraction = expandInteraction(schema, cols).map(_.toSet)
            includedTerms = includedTerms.filter(t => !fromInteraction.contains(t.toSet))
          case Dot =>
            // e.g. "- .", which removes all first-order terms
            includedTerms = includedTerms.filter {
              case Seq(t) => !dotTerms.contains(t)
              case _ => true
            }
          case _: Deletion =>
            throw new RuntimeException("Deletion terms cannot be nested")
          case _: Intercept =>
        }
      case _: Intercept =>
    }
    ResolvedRFormula(label.value, includedTerms.distinct, hasIntercept)
  }

结果解释

在这里插入图片描述

label其实就是r表达式中 ~前面的clicked字段取值,
feature就是由~后面字段构成的向量,我们看features中第一个列字段取值,其实就是对应了country字段去值,只不过把US CA分别有0 和 1 数字代替了,
第二列取值其实就是hour的取值,第三列其实就是count字段取值。

那么我们现在增加几个维度:
在这里插入图片描述

我们会发现features中有6个维度,其中前四个维度代表的是country的不同取值,第五个维度是hou r取值,第六个维度是count取值,我们看id=12的features,会发现出现了稀疏向量,其实就是向量大小是6,然后【4,5】代表第5 6个位置对应的值分别为2.0 99.0

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值