Spark -- OneHotEncoder, a Spark implementation

  Driven by a business requirement, multiple discrete (categorical) columns need to be one-hot encoded and expanded into N columns each. For example, a sex column with only two values, male and female, is expanded horizontally into two columns, sex_male and sex_female, which are appended to the original data. If a column has more distinct values than the configured maximum (for example 100), that column is not encoded.
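  As a quick illustration (hypothetical data), a sex column holding male/female comes back with two extra double-typed columns appended to the original ones. Note that the code below actually names the new columns with an _if_ infix, i.e. sex_if_male rather than sex_male:

    sex     | sex_if_male | sex_if_female
    male    | 1.0         | 0.0
    female  | 0.0         | 1.0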

  1. First, get the discrete columns to be transformed
  val disColsList: List[String] = configMap.getOrElse(OneHotEncoder.DIS_COLS, null).asInstanceOf[List[String]]
  var maxCates: Int = configMap.getOrElse(OneHotEncoder.MAX_CATES, 100).toString.toDouble.intValue()
  if (disColsList == null || disColsList.size == 0) throw new Exception("The list of discrete columns is empty!")
  2. Check whether these columns contain null values
//check for null values
    val nullValueCount = NullValueCheck.countNullValue(inputDF, userSelectCols)
    if (nullValueCount > 0) {
      throw new Exception("输入数据有" + nullValueCount + "条存在空值!")
    }
  3. Count the distinct values of each column
//columns with more distinct values than the user-defined limit are not encoded
    val aggMap = userSelectCols.map((_ -> "approx_count_distinct")).toMap
    val aggRows = inputDF.agg(aggMap).take(1)

    import spark.implicits._
    //approximate distinct count of each column
    val colsCountTuple: Array[(String, Long)] = userSelectCols.map(col => (col, aggRows(0).getAs[Long]("approx_count_distinct(" + col + ")")))
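  One caveat: the lookup above relies on the column name Spark generates for the map-style aggregate, "approx_count_distinct(col)". A slightly more defensive variant (a sketch, not the component's code; aggExprs, countRow and colsCount are illustrative names) aliases each aggregate explicitly:

    import org.apache.spark.sql.functions.{approx_count_distinct, col}
    //alias each aggregate so it can be read back under a name we control
    val aggExprs = userSelectCols.map(c => approx_count_distinct(col(c)).alias(c + "_cnt"))
    val countRow = inputDF.agg(aggExprs.head, aggExprs.tail: _*).head()
    val colsCount = userSelectCols.map(c => (c, countRow.getAs[Long](c + "_cnt")))

  Since approx_count_distinct is HyperLogLog-based, the cutoff around maxCates is approximate rather than exact.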
  4. Pick out the columns that will actually be encoded
//columns not taking part in the encoding
    val colsNotEncoded = colsCountTuple.filter(_._2 > maxCates)
    log("不参与编码的列有:" + colsNotEncoded.mkString(","), cmpt)
    //dataset of the columns not encoded
    val noEncodedDF = spark.createDataFrame(colsNotEncoded)
      .toDF("col_name", "distinct_count")

    //columns taking part in the encoding
    val transCols = colsCountTuple.filter(_._2 <= maxCates).map(_._1)
    log("参与编码的列有:" + transCols.mkString(","), cmpt)

  5. Label-encode the columns to be transformed
//index the discrete columns
    val (labelEncodeDF, inputCols) = encode(inputDF, transCols)
  6. One-hot encode the indexed columns
//one-hot encoding
    val estimator = new OneHotEncoderEstimator().setInputCols(inputCols).setOutputCols(inputCols.map(_ + "_encoded")).setDropLast(false)
    val oneHotDF = estimator.fit(labelEncodeDF).transform(labelEncodeDF).drop(inputCols: _*).cache()
    //force the cached dataframe to materialize
    println(oneHotDF.count())
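  Note that OneHotEncoderEstimator is the Spark 2.3/2.4 API; on Spark 3.x it was renamed to OneHotEncoder with the same setters. A minimal sketch of the equivalent call (estimator3x is an illustrative name, and the import is renamed here only to avoid clashing with the component's own OneHotEncoder object):

    import org.apache.spark.ml.feature.{OneHotEncoder => SparkOneHotEncoder}
    val estimator3x = new SparkOneHotEncoder()
      .setInputCols(inputCols).setOutputCols(inputCols.map(_ + "_encoded")).setDropLast(false)

  setDropLast(false) is what keeps one output slot per distinct value, so every category gets its own expanded column.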
  7. Convert the resulting vector columns and expand each into multiple plain columns
//collect the distinct-value counts of each encoded column
    val colValueCountMap = oneHotDF.rdd.treeAggregate(new LinkedHashMap[String, LinkedHashMap[String, Long]])(apply, merge)
    //udf that converts a vector into an array
    val vecToArray = udf((xs: SparseVector) => xs.toArray)
    //convert all vector columns in one pass
    val vecToArrExprs: Array[Column] = transCols.map(colName => {
      //name of the vector column
      val vectorColName = colName + "_indexed_encoded"
      vecToArray($"$vectorColName").alias(colName + "_arr_tmp")
    })
    val orginalCols = inputDF.schema.fieldNames.map(colName => $"$colName")
    //dataframe with the vectors converted to arrays
    val transArrDF = oneHotDF.select((orginalCols ++ vecToArrExprs): _*)

    //expand all array columns in one pass
    val extendExprs: Array[Column] = transCols.flatMap(colName => {
      //name of the array column
      val arrColName = colName + "_arr_tmp"
      //new column names for this column, sorted by frequency to match the StringIndexer/one-hot index order
      val newColNameArr = colValueCountMap.get(colName).get.toArray.sortBy(_._2).reverse.map(colName + "_if_" + _._1)
      newColNameArr.zipWithIndex.map { case (newColName, index) => {
        $"$arrColName".getItem(index).alias(newColName)
      }
      }
    })
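  The vecToArray UDF above only accepts SparseVector, which is what OneHotEncoderEstimator emits; on Spark 3.0+ the built-in vector_to_array function is a simpler alternative that also handles dense vectors. A sketch under that assumption (vecToArrExprs3x is an illustrative name):

    import org.apache.spark.ml.functions.vector_to_array
    import org.apache.spark.sql.functions.col
    val vecToArrExprs3x: Array[Column] = transCols.map(c =>
      vector_to_array(col(c + "_indexed_encoded")).alias(c + "_arr_tmp"))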
  8. Append the expanded columns to the original data and return
//final dataframe after expansion
    val finalDF = transArrDF.select((orginalCols ++ extendExprs): _*)
    (noEncodedDF, finalDF)

  The complete code of the component is as follows:
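  For completeness, these are the Spark and Scala imports the methods below rely on (assuming Spark 2.3/2.4, where OneHotEncoderEstimator lives in org.apache.spark.ml.feature); OneHotEncoder.DIS_COLS, OneHotEncoder.MAX_CATES, log and cmpt come from the author's component framework and are not shown:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{OneHotEncoderEstimator, StringIndexer}
import org.apache.spark.ml.linalg.SparseVector
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession}

import scala.collection.mutable.{ArrayBuffer, LinkedHashMap}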

/**
    * @Author: TheBigBlue
    * @Description: Use approx_count_distinct to enforce the category-count limit, and expand all the columns with select expressions in one pass
    * @Date: 2019/2/13
    * @param spark     : the active SparkSession
    * @param inputDF   : the input data
    * @param configMap : component configuration (discrete columns, max category count)
    * @Return: (dataframe listing the columns skipped for exceeding the limit, dataframe with the one-hot columns appended)
    **/
  def invokeOneHot(spark: SparkSession, inputDF: DataFrame, configMap: Map[String, Any]): (DataFrame, DataFrame) = {
    val disColsList: List[String] = configMap.getOrElse(OneHotEncoder.DIS_COLS, null).asInstanceOf[List[String]]
    var maxCates: Int = configMap.getOrElse(OneHotEncoder.MAX_CATES, 100).toString.toDouble.intValue()
    if (disColsList == null || disColsList.size == 0) throw new Exception("The list of discrete columns is empty!")
    //clamp the maximum category count to the allowed range
    maxCates = if(maxCates < 2) 2 else if(maxCates > 1000) 1000 else maxCates
    val userSelectCols = disColsList.toArray
    inputDF.cache()
    //check for null values
    val nullValueCount = NullValueCheck.countNullValue(inputDF, userSelectCols)
    if (nullValueCount > 0) {
      throw new Exception("输入数据有" + nullValueCount + "条存在空值!")
    }

    //columns with more distinct values than the user-defined limit are not encoded
    val aggMap = userSelectCols.map((_ -> "approx_count_distinct")).toMap
    val aggRows = inputDF.agg(aggMap).take(1)

    import spark.implicits._
    //approximate distinct count of each column
    val colsCountTuple: Array[(String, Long)] = userSelectCols.map(col => (col, aggRows(0).getAs[Long]("approx_count_distinct(" + col + ")")))
    //columns not taking part in the encoding
    val colsNotEncoded = colsCountTuple.filter(_._2 > maxCates)
    log("不参与编码的列有:" + colsNotEncoded.mkString(","), cmpt)
    //dataset of the columns not encoded
    val noEncodedDF = spark.createDataFrame(colsNotEncoded)
      .toDF("col_name", "distinct_count")

    //columns taking part in the encoding
    val transCols = colsCountTuple.filter(_._2 <= maxCates).map(_._1)
    log("参与编码的列有:" + transCols.mkString(","), cmpt)

    //label-encode (index) the discrete columns
    val (labelEncodeDF, inputCols) = encode(inputDF, transCols)

    //one-hot encoding
    val estimator = new OneHotEncoderEstimator().setInputCols(inputCols).setOutputCols(inputCols.map(_ + "_encoded")).setDropLast(false)
    val oneHotDF = estimator.fit(labelEncodeDF).transform(labelEncodeDF).drop(inputCols: _*).cache()
    //force the cached dataframe to materialize
    println(oneHotDF.count())
    //collect the distinct-value counts of each encoded column
    val colValueCountMap = oneHotDF.rdd.treeAggregate(new LinkedHashMap[String, LinkedHashMap[String, Long]])(apply, merge)
    //udf that converts a vector into an array
    val vecToArray = udf((xs: SparseVector) => xs.toArray)
    //convert all vector columns in one pass
    val vecToArrExprs: Array[Column] = transCols.map(colName => {
      //name of the vector column
      val vectorColName = colName + "_indexed_encoded"
      vecToArray($"$vectorColName").alias(colName + "_arr_tmp")
    })
    val orginalCols = inputDF.schema.fieldNames.map(colName => $"$colName")
    //dataframe with the vectors converted to arrays
    val transArrDF = oneHotDF.select((orginalCols ++ vecToArrExprs): _*)

    //expand all array columns in one pass
    val extendExprs: Array[Column] = transCols.flatMap(colName => {
      //name of the array column
      val arrColName = colName + "_arr_tmp"
      //new column names for this column, sorted by frequency to match the StringIndexer/one-hot index order
      val newColNameArr = colValueCountMap.get(colName).get.toArray.sortBy(_._2).reverse.map(colName + "_if_" + _._1)
      newColNameArr.zipWithIndex.map { case (newColName, index) => {
        $"$arrColName".getItem(index).alias(newColName)
      }
      }
    })
    //final dataframe after expansion
    val finalDF = transArrDF.select((orginalCols ++ extendExprs): _*)
    (noEncodedDF, finalDF)
  }

  /**
    * @Author: TheBigBlue
    * @Description: Per-partition counting of column values
    * @Date: 2019/2/13
    * @param map : accumulated per-column value counts
    * @param row : the current row
    * @Return:
    **/
  def apply(map: LinkedHashMap[String, LinkedHashMap[String, Long]], row: Row): LinkedHashMap[String, LinkedHashMap[String, Long]] = {
    //only the transformed (vector) columns are counted
    row.schema.fields.filter(_.dataType.typeName == "vector").foreach(field => {
      //original column name
      val originalColName = field.name.substring(0, field.name.indexOf("_indexed_encoded"))
      //value of this column in the current row
      val colValue = row.getAs(originalColName).toString
      //increment the count for this value
      if (map.getOrElse(originalColName, null) == null) {
        val countMap = new LinkedHashMap[String, Long]()
        countMap.put(colValue, 1)
        map.put(originalColName, countMap)
      } else {
        val countMap = map.get(originalColName).get
        if (countMap.getOrElse(colValue, null) == null) {
          countMap.put(colValue, 1)
        } else {
          countMap.put(colValue, countMap.get(colValue).get + 1)
        }
      }
    })
    map
  }

  /**
    * @Author: TheBigBlue
    * @Description: Merge the per-partition counts
    * @Date: 2019/2/13
    * @param map1 : counts accumulated so far
    * @param map2 : counts from another partition
    * @Return:
    **/
  def merge(map1: LinkedHashMap[String, LinkedHashMap[String, Long]],
            map2: LinkedHashMap[String, LinkedHashMap[String, Long]]): LinkedHashMap[String, LinkedHashMap[String, Long]] = {
    //sum the per-value counts from both partitions (a plain ++= would let map2 overwrite map1's counts)
    map2.foreach { case (colName, counts2) =>
      val counts1 = map1.getOrElseUpdate(colName, new LinkedHashMap[String, Long]())
      counts2.foreach { case (value, cnt) => counts1.put(value, counts1.getOrElse(value, 0L) + cnt) }
    }
    map1
  }

  /**
    * @Author: TheBigBlue
    * @Description: Label-encode the discrete columns
    * @Date: 2019/2/13
    * @param inputDF        : the input data
    * @param userSelectCols : the columns to index
    * @Return:
    **/
  def encode(inputDF: DataFrame, userSelectCols: Array[String]): (DataFrame, Array[String]) = {
    //collect the output column names of the transformation
    val inputCols = new ArrayBuffer[String]()
    //use a Pipeline to run all the StringIndexers in one pass
    val indexers = userSelectCols.map(colName => {
      val transColName = colName + "_indexed"
      inputCols += transColName
      new StringIndexer().setInputCol(colName).setOutputCol(transColName)
    })
    val labelEncodeDF = new Pipeline().setStages(indexers).fit(inputDF).transform(inputDF).cache()
    (labelEncodeDF, inputCols.toArray)
  }
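  The reason invokeOneHot sorts each column's values by count descending when naming the expanded columns is that StringIndexer assigns index 0 to the most frequent label by default (stringOrderType = "frequencyDesc"), and the one-hot vector positions follow those indices. A sketch that spells the ordering out explicitly (it is already the default, so this is optional; Spark 2.3+):

    new StringIndexer().setInputCol(colName).setOutputCol(transColName).setStringOrderType("frequencyDesc")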

 /**
   * @Author: TheBigBlue
   * @Description: Check whether the given columns of the dataframe contain any null values
   * @Date: 2019/1/4
   * @param dataFrame: the input dataframe
   * @param colsArr: the array of columns to check
   * @return: long, the number of rows that contain a null value
   **/
  def countNullValue(dataFrame: DataFrame, colsArr: Array[String]): Long = {
    var nullValueCount: Long = 0
    if(dataFrame != null && colsArr != null){
      val condition = colsArr.map(colName => "`" + colName + "` is null").mkString(" or ")
      nullValueCount = dataFrame.filter(condition).count()
    }
    nullValueCount
  }
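  A minimal usage sketch, assuming a SparkSession named spark, hypothetical column names, and that OneHotEncoder.DIS_COLS and OneHotEncoder.MAX_CATES are the component's String config keys:

    val df = spark.createDataFrame(Seq(
      ("male", "cat_a", 1.0), ("female", "cat_b", 2.0), ("male", "cat_a", 3.0)
    )).toDF("sex", "category", "score")
    val config: Map[String, Any] = Map(OneHotEncoder.DIS_COLS -> List("sex", "category"), OneHotEncoder.MAX_CATES -> 100)
    val (skippedDF, encodedDF) = invokeOneHot(spark, df, config)
    encodedDF.show()   //expect sex_if_* and category_if_* columns appended after sex, category, score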