【spark ML系列】Spark MLlib中的Bucketizer场景用法示例源码解析

Spark MLlib中的Bucketizer源码解析

1. 适用场景

Bucketizer将连续特征的列映射到特征桶的列。

该源码适用于以下场景:

  • 将连续特征划分为离散的桶,以进行进一步的处理或建模。
  • 处理包含连续特征的数据集,并将其转换为具有桶特征的新数据集。
  • 可以同时映射多个列,适用于批量处理多个连续特征的情况。

2. 多种主要用法及其代码示例

使用单个列进行映射

import org.apache.spark.ml.feature.Bucketizer

// 准备输入数据集
val data = Array(-0.5, 0.1, 1.5, 2.0)
val dataFrame = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features")

// 设置桶划分点
val splits = Array(-Double.PositiveInfinity, 0.0, 1.0, Double.PositiveInfinity)

// 创建Bucketizer对象并设置参数
val bucketizer = new Bucketizer()
  .setInputCol("features")
  .setOutputCol("bucketedFeatures")
  .setSplits(splits)

// 应用Bucketizer进行转换
val bucketedData = bucketizer.transform(dataFrame)
bucketedData.show()

输出结果:

+--------+----------------+
|features|bucketedFeatures|
+--------+----------------+
|    -0.5|             0.0|
|     0.1|             1.0|
|     1.5|             2.0|
|     2.0|             2.0|
+--------+----------------+

使用多个列进行映射

import org.apache.spark.ml.feature.Bucketizer

// 准备输入数据集
val data = Array((-0.5, 10.0), (0.1, 20.0), (1.5, 30.0), (2.0, 40.0))
val dataFrame = spark.createDataFrame(data).toDF("features1", "features2")

// 设置桶划分点
val splitsArray = Array(
  Array(-Double.PositiveInfinity, 0.0, Double.PositiveInfinity),
  Array(0.0, 1.0, 2.0, Double.PositiveInfinity)
)

// 创建Bucketizer对象并设置参数
val bucketizer = new Bucketizer()
  .setInputCols(Array("features1", "features2"))
  .setOutputCols(Array("bucketedFeatures1", "bucketedFeatures2"))
  .setSplitsArray(splitsArray)

// 应用Bucketizer进行转换
val bucketedData = bucketizer.transform(dataFrame)
bucketedData.show()

输出结果:

+---------+---------+-------------------+-------------------+
|features1|features2|bucketedFeatures1  |bucketedFeatures2  |
+---------+---------+-------------------+-------------------+
|-0.5     |10.0     |0.0                |1.0                |
|0.1      |20.0     |1.0                |2.0                |
|1.5      |30.0     |2.0                |3.0                |
|2.0      |40.0     |2.0                |4.0                |
+---------+---------+-------------------+-------------------+

3. 中文源码

/**
 * `Bucketizer`将连续特征的列映射到特征桶的列。
 *
 * 自2.3.0版本起,
 * `Bucketizer`可以通过设置`inputCols`参数一次性映射多个列。注意,当`inputCol`和`inputCols`参数都设置时,将抛出异常。
 * `splits`参数仅用于单列使用,`splitsArray`用于多列。
 */
@Since("1.4.0")
final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
  extends Model[Bucketizer] with HasHandleInvalid with HasInputCol with HasOutputCol
    with HasInputCols with HasOutputCols with DefaultParamsWritable {

  @Since("1.4.0")
  def this() = this(Identifiable.randomUID("bucketizer"))

  /**
   * 将连续特征映射为桶的参数。对于n+1个分割点,有n个桶。
   * 由分割点x、y定义的桶包含范围[x,y),最后一个桶也包括y。分割点应具有大于或等于3个且严格递增的长度。
   * 需要明确提供-inf、inf处的值以覆盖所有Double值;否则,超出指定分割点范围的值将被视为错误。
   *
   * 另请参见[[handleInvalid]],它可以选择为NaN值创建额外的桶。
   *
   * @group param
   */
  @Since("1.4.0")
  val splits: DoubleArrayParam = new DoubleArrayParam(this, "splits",
    "用于将连续特征映射到桶中的分割点。对于n+1个分割点,有n个桶。由分割点x、y定义的桶包含范围[x,y),最后一个桶也包括y。" +
      "分割点应具有长度>=3且严格递增。需要明确提供-inf、inf处的值以覆盖所有Double值;否则,超出指定分割点范围的值将被视为错误。",
    Bucketizer.checkSplits)

  /** @group getParam */
  @Since("1.4.0")
  def getSplits: Array[Double] = $(splits)

  /** @group setParam */
  @Since("1.4.0")
  def setSplits(value: Array[Double]): this.type = set(splits, value)

  /** @group setParam */
  @Since("1.4.0")
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  @Since("1.4.0")
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
   * 参数用于处理无效条目的方式。选项有'skip'(过滤掉具有无效值的行),
   * 'error'(抛出错误)或'keep'(将无效值保留在特殊的附加桶中)。
   * 注意,在多列情况下,无效处理适用于所有列。对于'error',如果任何列中存在无效值,它将抛出错误;
   * 对于'skip',如果任何列中存在无效值,它将跳过具有任何无效值的行等等。
   * 默认值:"error"
   * @group param
   */
  @Since("2.1.0")
  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
    "如何处理无效条目。选项为skip(过滤掉具有无效值的行),error(抛出错误)或keep(将无效值保留在特殊的附加桶中)。",
    ParamValidators.inArray(Bucketizer.supportedHandleInvalids))

  /** @group setParam */
  @Since("2.1.0")
  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
  setDefault(handleInvalid, Bucketizer.ERROR_INVALID)

  /**
   * 用于指定多个分割参数的参数。此数组中的每个元素都可以用于将连续特征映射到桶中。
   *
   * @group param
   */
  @Since("2.3.0")
  val splitsArray: DoubleArrayArrayParam = new DoubleArrayArrayParam(this, "splitsArray",
    "映射连续特征到多列桶的分割点数组。对于每个输入列,n+1个分割点将产生n个桶。由分割点x、y定义的桶包含范围[x,y)," +
      "最后一个桶也包括y。分割点应具有长度>=3且严格递增。需要明确提供-inf、inf处的值以覆盖所有Double值;" +
      "否则,超出指定分割点范围的值将被视为错误。",
    Bucketizer.checkSplitsArray)

  /** @group getParam */
  @Since("2.3.0")
  def getSplitsArray: Array[Array[Double]] = $(splitsArray)

  /** @group setParam */
  @Since("2.3.0")
  def setSplitsArray(value: Array[Array[Double]]): this.type = set(splitsArray, value)

  /** @group setParam */
  @Since("2.3.0")
  def setInputCols(value: Array[String]): this.type = set(inputCols, value)

  /** @group setParam */
  @Since("2.3.0")
  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

  @Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    val transformedSchema = transformSchema(dataset.schema)

    val (inputColumns, outputColumns) = if (isSet(inputCols)) {
      ($(inputCols).toSeq, $(outputCols).toSeq)
    } else {
      (Seq($(inputCol)), Seq($(outputCol)))
    }

    val (filteredDataset, keepInvalid) = {
      if (getHandleInvalid == Bucketizer.SKIP_INVALID) {
        // 如果设置了“skip” NaN选项,则过滤掉数据集中的NaN值
        (dataset.na.drop(inputColumns).toDF(), false)
      } else {
        (dataset.toDF(), getHandleInvalid == Bucketizer.KEEP_INVALID)
      }
    }

    val seqOfSplits = if (isSet(inputCols)) {
      $(splitsArray).toSeq
    } else {
      Seq($(splits))
    }

    val bucketizers: Seq[UserDefinedFunction] = seqOfSplits.zipWithIndex.map { case (splits, idx) =>
      udf { (feature: Double) =>
        Bucketizer.binarySearchForBuckets(splits, feature, keepInvalid)
      }.withName(s"bucketizer_$idx")
    }

    val newCols = inputColumns.zipWithIndex.map { case (inputCol, idx) =>
      bucketizers(idx)(filteredDataset(inputCol).cast(DoubleType))
    }
    val metadata = outputColumns.map { col =>
      transformedSchema(col).metadata
    }
    filteredDataset.withColumns(outputColumns, newCols, metadata)
  }

  private def prepOutputField(splits: Array[Double], outputCol: String): StructField = {
    val buckets = splits.sliding(2).map(bucket => bucket.mkString(", ")).toArray
    val attr = new NominalAttribute(name = Some(outputCol), isOrdinal = Some(true),
      values = Some(buckets))
    attr.toStructField()
  }

  @Since("1.4.0")
  override def transformSchema(schema: StructType): StructType = {
    ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol, splits),
      Seq(outputCols, splitsArray))

    if (isSet(inputCols)) {
      require(getInputCols.length == getOutputCols.length &&
        getInputCols.length == getSplitsArray.length, s"Bucketizer $this has mismatched Params " +
        s"for multi-column transform.  Params (inputCols, outputCols, splitsArray) should have " +
        s"equal lengths, but they have different lengths: " +
        s"(${getInputCols.length}, ${getOutputCols.length}, ${getSplitsArray.length}).")

      var transformedSchema = schema
      $(inputCols).zip($(outputCols)).zipWithIndex.foreach { case ((inputCol, outputCol), idx) =>
        SchemaUtils.checkNumericType(transformedSchema, inputCol)
        transformedSchema = SchemaUtils.appendColumn(transformedSchema,
          prepOutputField($(splitsArray)(idx), outputCol))
      }
      transformedSchema
    } else {
      SchemaUtils.checkNumericType(schema, $(inputCol))
      SchemaUtils.appendColumn(schema, prepOutputField($(splits), $(outputCol)))
    }
  }

  @Since("1.4.1")
  override def copy(extra: ParamMap): Bucketizer = {
    defaultCopy[Bucketizer](extra).setParent(parent)
  }
}

@Since("1.6.0")
object Bucketizer extends DefaultParamsReadable[Bucketizer] {

  private[feature] val SKIP_INVALID: String = "skip"
  private[feature] val ERROR_INVALID: String = "error"
  private[feature] val KEEP_INVALID: String = "keep"
  private[feature] val supportedHandleInvalids: Array[String] =
    Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)

  /**
   * 要求分割点的长度大于等于3,并且严格递增。
   * 不应接受NaN分割点。
   */
  private[feature] def checkSplits(splits: Array[Double]): Boolean = {
    if (splits.length < 3) {
      false
    } else {
      var i = 0
      val n = splits.length - 1
      while (i < n) {
        if (splits(i) >= splits(i + 1) || splits(i).isNaN) return false
        i += 1
      }
      !splits(n).isNaN
    }
  }

  /**
   * 检查分割点数组中的每个分割点。
   */
  private[feature] def checkSplitsArray(splitsArray: Array[Array[Double]]): Boolean = {
    splitsArray.forall(checkSplits(_))
  }

  /**
   * 在多个桶中进行二分搜索以将每个数据点放置到相应的桶中。
   * @param splits 分割点数组
   * @param feature 数据点
   * @param keepInvalid NaN标志。
   *                    设置为"true"以为NaN值创建一个额外的桶;
   *                    设置为"false"以报告NaN值的错误
   * @return 每个数据点的桶
   * @throws SparkException 如果特征值<分割点.head或>分割点.last
   */

  private[feature] def binarySearchForBuckets(
      splits: Array[Double],
      feature: Double,
      keepInvalid: Boolean): Double = {
    if (feature.isNaN) {
      if (keepInvalid) {
        splits.length - 1
      } else {
        throw new SparkException("Bucketizer遇到NaN值。要处理或跳过NaN值,请尝试设置Bucketizer.handleInvalid。")
      }
    } else if (feature == splits.last) {
      splits.length - 2
    } else {
      val idx = ju.Arrays.binarySearch(splits, feature)
      if (idx >= 0) {
        idx
      } else {
        val insertPos = -idx - 1
        if (insertPos == 0 || insertPos == splits.length) {
          throw new SparkException(s"特征值 $feature 超出Bucketizer范围[${splits.head}, ${splits.last}]。" +
            s"请检查您的特征值或放宽下限/上限约束。")
        } else {
          insertPos - 1
        }
      }
    }
  }

  @Since("1.6.0")
  override def load(path: String): Bucketizer = super.load(path)
}
​```

4. 官方链接

Bucketizer - Apache Spark

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

BigDataMLApplication

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值