Spark Feature Engineering: One-Hot and Multi-Hot Encoding

One-hot

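  import org.apache.spark.ml.feature.OneHotEncoderEstimator // Spark 2.3/2.4; in Spark 3.x this class is OneHotEncoder
  import org.apache.spark.sql.DataFrame
  import org.apache.spark.sql.functions.col
  import org.apache.spark.sql.types.IntegerType

  /**
   * One-hot encoding example function
   * @param samples movie samples dataframe
   */
  def oneHotEncoderExample(samples: DataFrame): Unit = {
    // The encoder needs a numeric input column, so cast the string movieId to Int
    val samplesWithIdNumber = samples.withColumn("movieIdNumber", col("movieId").cast(IntegerType))

    val oneHotEncoder = new OneHotEncoderEstimator()
      .setInputCols(Array("movieIdNumber"))
      .setOutputCols(Array("movieIdVector"))
      .setDropLast(false) // keep a slot for the last category instead of encoding it as all zeros

    // fit() determines the number of categories, transform() appends the sparse vector column
    val oneHotEncoderSamples = oneHotEncoder.fit(samplesWithIdNumber).transform(samplesWithIdNumber)
    oneHotEncoderSamples.printSchema()
    oneHotEncoderSamples.show(10)
  }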
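To make printed vectors such as `(1001,[1],[1.0])` easier to read, here is a minimal plain-Scala sketch (no Spark dependency) of what the encoder does to an integer column. It assumes, consistent with the size 1001 in the sample output, that the number of categories is the largest id + 1, and that `dropLast = true` turns the last category into the all-zero vector. The object `OneHotSketch`, the `Sparse` triple `(size, indices, values)` and `encodeAll` are illustrative names, not Spark API.

```scala
// Hypothetical, Spark-free imitation of one-hot encoding an integer column.
object OneHotSketch {
  // Sparse vector as (size, indices, values), the triple Spark prints, e.g. (1001,[1],[1.0])
  type Sparse = (Int, Array[Int], Array[Double])

  // Assumption: number of categories = max id + 1; dropLast removes the last slot,
  // so the last category becomes the all-zero vector.
  def encodeAll(ids: Seq[Int], dropLast: Boolean): Seq[Sparse] = {
    val numCategories = ids.max + 1
    val size = if (dropLast) numCategories - 1 else numCategories
    ids.map { id =>
      if (id < size) (size, Array(id), Array(1.0))
      else (size, Array.empty[Int], Array.empty[Double])
    }
  }

  // Render the triple the way Spark shows a SparseVector
  def render(v: Sparse): String =
    s"(${v._1},[${v._2.mkString(",")}],[${v._3.mkString(",")}])"

  def main(args: Array[String]): Unit = {
    val ids = Seq(1, 2, 10)
    // prints (11,[1],[1.0]), (11,[2],[1.0]), (11,[10],[1.0]) on separate lines
    encodeAll(ids, dropLast = false).map(render).foreach(println)
    // with dropLast, id 10 becomes (10,[],[])
    encodeAll(ids, dropLast = true).map(render).foreach(println)
  }
}
```

Setting `dropLast` to false, as the original code does, keeps every movieId distinguishable at the cost of one extra (linearly dependent) column.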

Multi-hot


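  import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel}
  import org.apache.spark.ml.linalg.Vectors
  import org.apache.spark.sql.DataFrame
  import org.apache.spark.sql.expressions.UserDefinedFunction
  import org.apache.spark.sql.functions.{col, collect_list, explode, max, split, typedLit, udf}
  import org.apache.spark.sql.types.IntegerType

  // UDF: turn an array of genre indices into a sparse 0/1 vector of the given length.
  // collect_list gives no ordering guarantee, so the indices are de-duplicated and sorted first.
  val array2vec: UserDefinedFunction = udf { (a: Seq[Int], length: Int) =>
    val indices = a.distinct.sorted.toArray
    Vectors.sparse(length, indices, Array.fill[Double](indices.length)(1.0))
  }

  /**
   * Multi-hot encoding example function
   * @param samples movie samples dataframe
   */
  def multiHotEncoderExample(samples: DataFrame): Unit = {
    // "Adventure|Animation|Children" -> one row per genre
    val samplesWithGenre = samples.select(col("movieId"), col("title"),
      explode(split(col("genres"), "\\|")).as("genre"))

    // StringIndexer maps each genre to a double index; by default the most frequent genre gets 0.0
    val genreIndexer = new StringIndexer().setInputCol("genre").setOutputCol("genreIndex")
    val stringIndexerModel: StringIndexerModel = genreIndexer.fit(samplesWithGenre)

    val genreIndexSamples = stringIndexerModel.transform(samplesWithGenre)
      .withColumn("genreIndexInt", col("genreIndex").cast(IntegerType))

    // Debug output; uncommenting it prints the genreIndexSamples tables in the sample output
    // println("genreIndexSamples:")
    // genreIndexSamples.printSchema()
    // genreIndexSamples.show(10, false)
    // println("genreIndexSamples.agg:")
    // genreIndexSamples.agg(max(col("genreIndexInt"))).show(10, false)

    // Indices run from 0 to (number of genres - 1), so the vector length is max index + 1
    val indexSize = genreIndexSamples.agg(max(col("genreIndexInt"))).head().getAs[Int](0) + 1

    // Collect each movie's genre indices into one array and attach the (constant) vector length
    val processedSamples = genreIndexSamples
      .groupBy(col("movieId")).agg(collect_list("genreIndexInt").as("genreIndexes"))
      .withColumn("indexSize", typedLit(indexSize))

    val finalSample = processedSamples.withColumn("vector", array2vec(col("genreIndexes"), col("indexSize")))
    finalSample.printSchema()
    finalSample.show(10, false)
  }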
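The grouping and vector-building steps can be mirrored on local collections, which makes it easy to check what `collect_list` and `array2vec` produce. This is a sketch under assumptions: `MultiHotSketch`, `multiHot` and `groupIndexes` are hypothetical names, the (movieId, genreIndexInt) pairs are copied from the sample-output rows for movies 296 and 467, and `indexSize = 19` is taken from that run.

```scala
// Hypothetical, Spark-free mirror of groupBy(movieId) + collect_list + array2vec.
object MultiHotSketch {
  type Sparse = (Int, Array[Int], Array[Double])

  // Same idea as the array2vec UDF: de-duplicated, ascending indices, each with value 1.0
  def multiHot(indexes: Seq[Int], length: Int): Sparse = {
    val idx = indexes.distinct.sorted.toArray
    require(idx.forall(i => i >= 0 && i < length), "index out of range")
    (length, idx, Array.fill(idx.length)(1.0))
  }

  // Local equivalent of groupBy("movieId").agg(collect_list("genreIndexInt"))
  def groupIndexes(rows: Seq[(String, Int)]): Map[String, Seq[Int]] =
    rows.groupBy(_._1).map { case (id, rs) => id -> rs.map(_._2) }

  def render(v: Sparse): String =
    s"(${v._1},[${v._2.mkString(",")}],[${v._3.mkString(",")}])"

  def main(args: Array[String]): Unit = {
    val rows = Seq(("296", 1), ("296", 5), ("296", 0), ("296", 3), ("467", 1))
    val indexSize = 19 // max(genreIndexInt) + 1 in the sample run
    groupIndexes(rows).toSeq.sortBy(_._1).foreach { case (id, ix) =>
      println(s"$id ${render(multiHot(ix, indexSize))}")
    }
  }
}
```

Sorting matters because `collect_list` returns the indices in whatever order rows arrive after the shuffle, as movie 296's `[1, 5, 0, 3]` shows.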

Notes (related topics):
StringIndexer usage
lit and typedLit
collect_list
Using agg
Spark aggregate functions
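Of the topics above, StringIndexer's ordering decides which vector slot each genre lands in. By default (`stringOrderType = "frequencyDesc"`) the most frequent label gets index 0. The plain-Scala sketch below imitates that ordering; `StringIndexerSketch` and `fitLabels` are hypothetical names, and breaking ties alphabetically is an assumption made here to keep the result deterministic, not a claim about every Spark version.

```scala
// Hypothetical, Spark-free imitation of StringIndexer's default "frequencyDesc" ordering.
object StringIndexerSketch {
  def fitLabels(labels: Seq[String]): Map[String, Int] =
    labels
      .groupBy(identity)
      .map { case (label, occurrences) => (label, occurrences.size) }
      .toSeq
      // higher count first; ties broken alphabetically (assumption, for determinism)
      .sortBy { case (label, count) => (-count, label) }
      .map(_._1)
      .zipWithIndex
      .toMap

  def main(args: Array[String]): Unit = {
    val genres = Seq("Comedy", "Drama", "Drama", "Romance", "Comedy", "Drama")
    // expected mapping: Drama -> 0, Comedy -> 1, Romance -> 2
    println(fitLabels(genres))
  }
}
```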

Sample output:
one-hot
Raw Movie Samples:
root
 |-- movieId: string (nullable = true)
 |-- title: string (nullable = true)
 |-- genres: string (nullable = true)

+-------+--------------------+--------------------+
|movieId|               title|              genres|
+-------+--------------------+--------------------+
|      1|    Toy Story (1995)|Adventure|Animati...|
|      2|      Jumanji (1995)|Adventure|Childre...|
|      3|Grumpier Old Men ...|      Comedy|Romance|
|      4|Waiting to Exhale...|Comedy|Drama|Romance|
|      5|Father of the Bri...|              Comedy|
|      6|         Heat (1995)|Action|Crime|Thri...|
|      7|      Sabrina (1995)|      Comedy|Romance|
|      8| Tom and Huck (1995)|  Adventure|Children|
|      9| Sudden Death (1995)|              Action|
|     10|    GoldenEye (1995)|Action|Adventure|...|
+-------+--------------------+--------------------+
only showing top 10 rows

OneHotEncoder Example:
root
 |-- movieId: string (nullable = true)
 |-- title: string (nullable = true)
 |-- genres: string (nullable = true)
 |-- movieIdNumber: integer (nullable = true)
 |-- movieIdVector: vector (nullable = true)

+-------+--------------------+--------------------+-------------+-----------------+
|movieId|               title|              genres|movieIdNumber|    movieIdVector|
+-------+--------------------+--------------------+-------------+-----------------+
|      1|    Toy Story (1995)|Adventure|Animati...|            1| (1001,[1],[1.0])|
|      2|      Jumanji (1995)|Adventure|Childre...|            2| (1001,[2],[1.0])|
|      3|Grumpier Old Men ...|      Comedy|Romance|            3| (1001,[3],[1.0])|
|      4|Waiting to Exhale...|Comedy|Drama|Romance|            4| (1001,[4],[1.0])|
|      5|Father of the Bri...|              Comedy|            5| (1001,[5],[1.0])|
|      6|         Heat (1995)|Action|Crime|Thri...|            6| (1001,[6],[1.0])|
|      7|      Sabrina (1995)|      Comedy|Romance|            7| (1001,[7],[1.0])|
|      8| Tom and Huck (1995)|  Adventure|Children|            8| (1001,[8],[1.0])|
|      9| Sudden Death (1995)|              Action|            9| (1001,[9],[1.0])|
|     10|    GoldenEye (1995)|Action|Adventure|...|           10|(1001,[10],[1.0])|
+-------+--------------------+--------------------+-------------+-----------------+

multi-hot
MultiHotEncoder Example:
genreIndexSamples:
root
 |-- movieId: string (nullable = true)
 |-- title: string (nullable = true)
 |-- genre: string (nullable = true)
 |-- genreIndex: double (nullable = false)
 |-- genreIndexInt: integer (nullable = true)

+-------+-----------------------+---------+----------+-------------+
|movieId|title                  |genre    |genreIndex|genreIndexInt|
+-------+-----------------------+---------+----------+-------------+
|1      |Toy Story (1995)       |Adventure|6.0       |6            |
|1      |Toy Story (1995)       |Animation|15.0      |15           |
|1      |Toy Story (1995)       |Children |7.0       |7            |
|1      |Toy Story (1995)       |Comedy   |1.0       |1            |
|1      |Toy Story (1995)       |Fantasy  |10.0      |10           |
|2      |Jumanji (1995)         |Adventure|6.0       |6            |
|2      |Jumanji (1995)         |Children |7.0       |7            |
|2      |Jumanji (1995)         |Fantasy  |10.0      |10           |
|3      |Grumpier Old Men (1995)|Comedy   |1.0       |1            |
|3      |Grumpier Old Men (1995)|Romance  |2.0       |2            |
+-------+-----------------------+---------+----------+-------------+
 
 
genreIndexSamples.agg:
+------------------+
|max(genreIndexInt)|
+------------------+
|18                |
+------------------+

finalSample:
root
 |-- movieId: string (nullable = true)
 |-- genreIndexes: array (nullable = true)
 |    |-- element: integer (containsNull = true)
 |-- indexSize: integer (nullable = false)
 |-- vector: vector (nullable = true)

+-------+------------+---------+--------------------------------+
|movieId|genreIndexes|indexSize|vector                          |
+-------+------------+---------+--------------------------------+
|296    |[1, 5, 0, 3]|19       |(19,[0,1,3,5],[1.0,1.0,1.0,1.0])|
|467    |[1]         |19       |(19,[1],[1.0])                  |
|675    |[4, 0, 3]   |19       |(19,[0,3,4],[1.0,1.0,1.0])      |
|691    |[1, 2]      |19       |(19,[1,2],[1.0,1.0])            |
|829    |[1, 10, 14] |19       |(19,[1,10,14],[1.0,1.0,1.0])    |
|125    |[1]         |19       |(19,[1],[1.0])                  |
|451    |[0, 8, 2]   |19       |(19,[0,2,8],[1.0,1.0,1.0])      |
|800    |[0, 8, 16]  |19       |(19,[0,8,16],[1.0,1.0,1.0])     |
|853    |[0]         |19       |(19,[0],[1.0])                  |
|944    |[0]         |19       |(19,[0],[1.0])                  |
+-------+------------+---------+--------------------------------+

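Another multi-hot approach (suited to a small number of tags)
It relies on getWordsIndexMap to build a tag-to-index map, which is then used to turn each record's tags into vector indices.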


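  import org.apache.spark.broadcast.Broadcast
  import org.apache.spark.ml.linalg.{Vector, Vectors}
  import org.apache.spark.rdd.RDD
  import org.apache.spark.sql.{DataFrame, SparkSession}

  // Collect every distinct tag, sort the tags, and broadcast a tag -> index map.
  // All sets are merged under a single key, so the final merge runs in one task:
  // this only suits a small tag vocabulary.
  def getWordsIndexMap(rdd: RDD[Set[String]], ss: SparkSession): Broadcast[Map[String, Int]] = {
    val allWords = rdd.map { x => (1, x) }.reduceByKey((x, y) => x ++ y).collect().head._2.toArray.sorted
    val wordsMapbt = ss.sparkContext.broadcast(allWords.zipWithIndex.toMap)
    wordsMapbt
  }

  // rdd holds (id, keywords, from); mp is the broadcast map returned by getWordsIndexMap
  def transformVec(rdd: RDD[(String, Set[String], String)], ss: SparkSession,
                   mp: Broadcast[Map[String, Int]]): (DataFrame, RDD[(Long, Vector, String)]) = {
    import ss.implicits._
    // Assign a numeric index to every distinct id
    val indexDF = rdd.map { x => x._1 }.distinct().zipWithUniqueId().toDF("id", "index")
    val outRDD = rdd.toDF("id", "keywords", "from")
      .join(indexDF, "id")
      .select("index", "keywords", "from")
      .rdd
      .map { row =>
        val index = row.getAs[Long]("index")
        val keywords = row.getAs[Seq[String]]("keywords")
        val from = row.getAs[String]("from")
        val len = mp.value.size
        // The map was built from sorted words, so sorting the keywords yields ascending indices
        val arr1 = keywords.toArray.sorted.map { x => mp.value(x) }
        val arr2 = arr1.map { _ => 1.0 }
        (index, Vectors.sparse(len, arr1, arr2), from)
      }
    (indexDF, outRDD)
  }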
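The same idea can be tried without a cluster. This sketch replaces the RDD `reduceByKey` with a local `foldLeft` union and the broadcast variable with a plain `Map`, and leaves out the `zipWithUniqueId` id indexing and the join. `TagMultiHotSketch`, `wordsIndexMap`, `toVector` and the sample records are made up for illustration.

```scala
// Hypothetical, Spark-free mirror of getWordsIndexMap + transformVec.
object TagMultiHotSketch {
  type Sparse = (Int, Array[Int], Array[Double])

  // Union of all tag sets, sorted, then tag -> position (what gets broadcast in the Spark version)
  def wordsIndexMap(records: Seq[Set[String]]): Map[String, Int] =
    records.foldLeft(Set.empty[String])(_ ++ _).toArray.sorted.zipWithIndex.toMap

  // Sorting tags before lookup gives ascending indices, because the map was built in sorted order
  def toVector(tags: Set[String], mp: Map[String, Int]): Sparse = {
    val idx = tags.toArray.sorted.map(t => mp(t))
    (mp.size, idx, idx.map(_ => 1.0))
  }

  def main(args: Array[String]): Unit = {
    val records = Seq(
      ("a", Set("spark", "scala"), "web"),
      ("b", Set("ml"), "app"))
    val mp = wordsIndexMap(records.map(_._2))
    records.foreach { case (id, tags, from) =>
      val (size, idx, vals) = toVector(tags, mp)
      println(s"$id ($size,[${idx.mkString(",")}],[${vals.mkString(",")}]) $from")
    }
  }
}
```

Because the whole vocabulary has to fit in one in-memory map on every executor, this variant is only attractive when the tag set is small, which matches the caveat in the heading above.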