A business requirement calls for applying OneHotEncoder to multiple discrete columns and expanding each of them into N new columns. For example, a sex column with only two values, male and female, is expanded horizontally into two columns, sex_male and sex_female, which are appended to the original data. If a column has more distinct values than the configured maximum (for example 100), that column is not encoded.
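For illustration, here is a minimal sketch of the intended input and expected output on a toy DataFrame (the sample data are assumptions; the component implemented below actually names the expanded columns with an "_if_" infix, e.g. sex_if_male):
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("one-hot-expand-demo").getOrCreate()
import spark.implicits._
//Toy input: one discrete column "sex" with two distinct values
val input = Seq(("u1", "male"), ("u2", "female"), ("u3", "male")).toDF("id", "sex")
//Expected output: the original columns plus one 0/1 column per distinct value, appended to the data
val expected = Seq(
("u1", "male", 1.0, 0.0),
("u2", "female", 0.0, 1.0),
("u3", "male", 1.0, 0.0)
).toDF("id", "sex", "sex_if_male", "sex_if_female")
expected.show()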
- First, get the discrete columns that need to be transformed
val disColsList: List[String] = configMap.getOrElse(OneHotEncoder.DIS_COLS, null).asInstanceOf[List[String]]
var maxCates: Int = configMap.getOrElse(OneHotEncoder.MAX_CATES, 100).toString.toDouble.intValue()
if (disColsList == null || disColsList.size == 0) throw new Exception("The list of discrete columns is empty!")
- Check these columns for null values
//Check for null values
val nullValueCount = NullValueCheck.countNullValue(inputDF, userSelectCols)
if (nullValueCount > 0) {
throw new Exception("输入数据有" + nullValueCount + "条存在空值!")
}
- Count the distinct values in each column
//Columns with more distinct values than the user-defined limit are not encoded
val aggMap = userSelectCols.map((_ -> "approx_count_distinct")).toMap
val aggRows = inputDF.agg(aggMap).take(1)
import spark.implicits._
//Distinct count for each column
val colsCountTuple: Array[(String, Long)] = userSelectCols.map(col => (col, aggRows(0).getAs[Long]("approx_count_distinct(" + col + ")")))
- Select the columns to be transformed
//Columns excluded from encoding
val colsNotEncoded = colsCountTuple.filter(_._2 > maxCates)
log("不参与编码的列有:" + colsNotEncoded.mkString(","), cmpt)
//不参与编码的列的dataset
val noEncodedDF = spark.createDataFrame(colsNotEncoded)
.toDF("col_name", "distinct_count")
//Columns to be encoded
val transCols = colsCountTuple.filter(_._2 <= maxCates).map(_._1)
log("参与编码的列有:" + transCols.mkString(","), cmpt)
- 对需要转换的列编码
//离散列编码
val (labelEncodeDF, inputCols) = encode(inputDF, transCols)
- One-hot encode the indexed columns
//One-hot encoding
val estimator = new OneHotEncoderEstimator().setInputCols(inputCols).setOutputCols(inputCols.map(_ + "_encoded")).setDropLast(false)
val oneHotDF = estimator.fit(labelEncodeDF).transform(labelEncodeDF).drop(inputCols: _*).cache()
println(oneHotDF.count())
- Convert the vector columns produced by the one-hot step and expand each into multiple columns
//Collect the distinct-value counts for each column
val colValueCountMap = oneHotDF.rdd.treeAggregate(new LinkedHashMap[String, LinkedHashMap[String, Long]])(apply, merge)
//UDF that converts a vector to an array
val vecToArray = udf((xs: SparseVector) => xs.toArray)
//Convert all vector columns in one pass
val vecToArrExprs: Array[Column] = transCols.map(colName => {
//Name of the vector column
val vectorColName = colName + "_indexed_encoded"
vecToArray($"$vectorColName").alias(colName + "_arr_tmp")
})
val orginalCols = inputDF.schema.fieldNames.map(colName => $"$colName")
//DataFrame after converting the vectors to arrays
val transArrDF = oneHotDF.select((orginalCols ++ vecToArrExprs): _*)
//Expand all array columns in one pass
val extendExprs: Array[Column] = transCols.flatMap(colName => {
//Name of the array column
val arrColName = colName + "_arr_tmp"
//Build the expanded column names for this column, ordered by value frequency (descending) to match StringIndexer's index order
val newColNameArr = colValueCountMap.get(colName).get.toArray.sortBy(_._2).reverse.map(colName + "_if_" + _._1)
newColNameArr.zipWithIndex.map { case (newColName, index) => {
$"$arrColName".getItem(index).alias(newColName)
}
}
})
- Append to the original data and return
//Final DataFrame after expansion
val finalDF = transArrDF.select((orginalCols ++ extendExprs): _*)
(noEncodedDF, finalDF)
The complete code of the component is as follows:
/**
* @Author: TheBigBlue
* @Description: Use approx_count_distinct to enforce the category-count limit, and expand all columns in a single select
* @Date: 2019/2/13
* @param spark : current SparkSession
* @param inputDF : input data
* @param configMap : component configuration (discrete columns and max category count)
* @Return: (DataFrame of columns excluded from encoding, DataFrame with the expanded columns)
**/
def invokeOneHot(spark: SparkSession, inputDF: DataFrame, configMap: Map[String, Any]): (DataFrame, DataFrame) = {
val disColsList: List[String] = configMap.getOrElse(OneHotEncoder.DIS_COLS, null).asInstanceOf[List[String]]
var maxCates: Int = configMap.getOrElse(OneHotEncoder.MAX_CATES, 100).toString.toDouble.intValue()
if (disColsList == null || disColsList.size == 0) throw new Exception("The list of discrete columns is empty!")
//Clamp maxCates to the range [2, 1000]
maxCates = if(maxCates < 2) 2 else if(maxCates > 1000) 1000 else maxCates
val userSelectCols = disColsList.toArray
inputDF.cache()
//Check for null values
val nullValueCount = NullValueCheck.countNullValue(inputDF, userSelectCols)
if (nullValueCount > 0) {
throw new Exception("输入数据有" + nullValueCount + "条存在空值!")
}
//Columns with more distinct values than the user-defined limit are not encoded
val aggMap = userSelectCols.map((_ -> "approx_count_distinct")).toMap
val aggRows = inputDF.agg(aggMap).take(1)
import spark.implicits._
//Distinct count for each column
val colsCountTuple: Array[(String, Long)] = userSelectCols.map(col => (col, aggRows(0).getAs[Long]("approx_count_distinct(" + col + ")")))
//Columns excluded from encoding
val colsNotEncoded = colsCountTuple.filter(_._2 > maxCates)
log("不参与编码的列有:" + colsNotEncoded.mkString(","), cmpt)
//不参与编码的列的dataset
val noEncodedDF = spark.createDataFrame(colsNotEncoded)
.toDF("col_name", "distinct_count")
//Columns to be encoded
val transCols = colsCountTuple.filter(_._2 <= maxCates).map(_._1)
log("参与编码的列有:" + transCols.mkString(","), cmpt)
//离散列编码
val (labelEncodeDF, inputCols) = encode(inputDF, transCols)
//One-hot encoding
val estimator = new OneHotEncoderEstimator().setInputCols(inputCols).setOutputCols(inputCols.map(_ + "_encoded")).setDropLast(false)
val oneHotDF = estimator.fit(labelEncodeDF).transform(labelEncodeDF).drop(inputCols: _*).cache()
println(oneHotDF.count())
//Collect the distinct-value counts for each column
val colValueCountMap = oneHotDF.rdd.treeAggregate(new LinkedHashMap[String, LinkedHashMap[String, Long]])(apply, merge)
//UDF that converts a vector to an array
val vecToArray = udf((xs: SparseVector) => xs.toArray)
//Convert all vector columns in one pass
val vecToArrExprs: Array[Column] = transCols.map(colName => {
//Name of the vector column
val vectorColName = colName + "_indexed_encoded"
vecToArray($"$vectorColName").alias(colName + "_arr_tmp")
})
val orginalCols = inputDF.schema.fieldNames.map(colName => $"$colName")
//DataFrame after converting the vectors to arrays
val transArrDF = oneHotDF.select((orginalCols ++ vecToArrExprs): _*)
//Expand all array columns in one pass
val extendExprs: Array[Column] = transCols.flatMap(colName => {
//Name of the array column
val arrColName = colName + "_arr_tmp"
//Build the expanded column names for this column, ordered by value frequency (descending) to match StringIndexer's index order
val newColNameArr = colValueCountMap.get(colName).get.toArray.sortBy(_._2).reverse.map(colName + "_if_" + _._1)
newColNameArr.zipWithIndex.map { case (newColName, index) => {
$"$arrColName".getItem(index).alias(newColName)
}
}
})
//Final DataFrame after expansion
val finalDF = transArrDF.select((orginalCols ++ extendExprs): _*)
(noEncodedDF, finalDF)
}
/**
* @Author: TheBigBlue
* @Description: Count values within each partition
* @Date: 2019/2/13
* @param map : accumulated per-column value counts
* @param row : current row
* @Return: the updated map
**/
def apply(map: LinkedHashMap[String, LinkedHashMap[String, Long]], row: Row): LinkedHashMap[String, LinkedHashMap[String, Long]] = {
//Only count the columns that were transformed (their encoded columns have vector type)
row.schema.fields.filter(_.dataType.typeName == "vector").foreach(field => {
//Original column name
val originalColName = field.name.substring(0, field.name.indexOf("_indexed_encoded"))
//Value of this column in the current row
val colValue = row.getAs(originalColName).toString
//Increment the count
if (map.getOrElse(originalColName, null) == null) {
val countMap = new LinkedHashMap[String, Long]()
countMap.put(colValue, 1)
map.put(originalColName, countMap)
} else {
val countMap = map.get(originalColName).get
if (countMap.getOrElse(colValue, null) == null) {
countMap.put(colValue, 1)
} else {
countMap.put(colValue, countMap.get(colValue).get + 1)
}
}
})
map
}
/**
* @Author: TheBigBlue
* @Description: Merge the per-partition counts (counts for the same column value are summed)
* @Date: 2019/2/13
* @param map1 : counts from one partition
* @param map2 : counts from another partition
* @Return: the merged counts
**/
def merge(map1: LinkedHashMap[String, LinkedHashMap[String, Long]],
map2: LinkedHashMap[String, LinkedHashMap[String, Long]]): LinkedHashMap[String, LinkedHashMap[String, Long]] = {
map2.foreach { case (colName, countMap2) =>
val countMap1 = map1.getOrElseUpdate(colName, new LinkedHashMap[String, Long]())
countMap2.foreach { case (value, count) => countMap1.put(value, countMap1.getOrElse(value, 0L) + count) }
}
map1
}
/**
* @Author: TheBigBlue
* @Description: Index the discrete columns
* @Date: 2019/2/13
* @param inputDF : input data
* @param userSelectCols : columns to index
* @Return: (indexed DataFrame, names of the indexed columns)
**/
def encode(inputDF: DataFrame, userSelectCols: Array[String]): (DataFrame, Array[String]) = {
//Collect the names of the indexed output columns
val inputCols = new ArrayBuffer[String]()
//Use a Pipeline to transform everything in one pass
val indexers = userSelectCols.map(colName => {
val transColName = colName + "_indexed"
inputCols += transColName
new StringIndexer().setInputCol(colName).setOutputCol(transColName)
})
val labelEncodeDF = new Pipeline().setStages(indexers).fit(inputDF).transform(inputDF).cache()
(labelEncodeDF, inputCols.toArray)
}
/**
* @Author: TheBigBlue
* @Description: Check whether the given columns of the DataFrame contain null values
* @Date: 2019/1/4
* @param dataFrame: input DataFrame
* @param colsArr: array of column names to check
* @return: long, the number of rows containing null values
**/
def countNullValue(dataFrame: DataFrame, colsArr: Array[String]): Long = {
var nullValueCount: Long = 0
if(dataFrame != null && colsArr != null){
val condition = colsArr.map(colName => "`" + colName + "` is null").mkString(" or ")
nullValueCount = dataFrame.filter(condition).count()
}
nullValueCount
}