One-hot
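import org.apache.spark.ml.feature.OneHotEncoderEstimator
import org.apache.spark.sql
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col

/**
 * One-hot encoding example function
 * @param samples movie samples dataframe
 */
def oneHotEncoderExample(samples: DataFrame): Unit = {
  // movieId is read as a string; the encoder needs a numeric category index
  val samplesWithIdNumber = samples.withColumn("movieIdNumber", col("movieId").cast(sql.types.IntegerType))

  val oneHotEncoder = new OneHotEncoderEstimator()
    .setInputCols(Array("movieIdNumber"))
    .setOutputCols(Array("movieIdVector"))
    .setDropLast(false) // keep a slot for every category, including the last one

  val oneHotEncoderSamples = oneHotEncoder.fit(samplesWithIdNumber).transform(samplesWithIdNumber)
  oneHotEncoderSamples.printSchema()
  oneHotEncoderSamples.show(10)
}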
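In Spark 3.0 and later, `OneHotEncoderEstimator` was renamed to `OneHotEncoder`, so the function above does not compile there as written. A minimal sketch of the equivalent on Spark 3.x, assuming the same string `movieId` column; `OneHotSpark3Sketch` and `oneHotEncode` are hypothetical names, and the function returns the DataFrame instead of printing it.

```scala
import org.apache.spark.ml.feature.OneHotEncoder
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.IntegerType

object OneHotSpark3Sketch {
  // Spark 3.x variant: same setters, different class name
  def oneHotEncode(samples: DataFrame): DataFrame = {
    // Cast the string id to an integer category index
    val withIdNumber = samples.withColumn("movieIdNumber", col("movieId").cast(IntegerType))

    val encoder = new OneHotEncoder()
      .setInputCols(Array("movieIdNumber"))
      .setOutputCols(Array("movieIdVector"))
      .setDropLast(false)

    encoder.fit(withIdNumber).transform(withIdNumber)
  }
}
```

Returning the DataFrame rather than calling `show` makes the result easier to inspect or feed into later stages.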
Multi-hot
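import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel}
import org.apache.spark.sql
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._

// Builds a sparse multi-hot vector of the given length from a list of category indexes.
// Vectors.sparse requires strictly increasing indices, hence the sort; the list must not contain duplicates.
val array2vec: UserDefinedFunction = udf { (a: Seq[Int], length: Int) =>
  org.apache.spark.ml.linalg.Vectors.sparse(length, a.sortWith(_ < _).toArray, Array.fill[Double](a.length)(1.0))
}

/**
 * Multi-hot encoding example function
 * @param samples movie samples dataframe
 */
def multiHotEncoderExample(samples: DataFrame): Unit = {
  // One row per (movie, genre): split "A|B|C" on "|" and explode the resulting array
  val samplesWithGenre = samples.select(col("movieId"), col("title"),
    explode(split(col("genres"), "\\|").cast("array<string>")).as("genre"))

  // Map each genre string to a numeric index (StringIndexer outputs doubles, so cast to int)
  val genreIndexer = new StringIndexer().setInputCol("genre").setOutputCol("genreIndex")
  val stringIndexerModel: StringIndexerModel = genreIndexer.fit(samplesWithGenre)
  val genreIndexSamples = stringIndexerModel.transform(samplesWithGenre)
    .withColumn("genreIndexInt", col("genreIndex").cast(sql.types.IntegerType))

  /* println("genreIndexSamples:")
  genreIndexSamples.printSchema()
  genreIndexSamples.show(10,false)
  println("genreIndexSamples.agg:")
  genreIndexSamples.agg(max(col("genreIndexInt"))).show(10,false)*/

  // Vector length = largest genre index + 1
  val indexSize = genreIndexSamples.agg(max(col("genreIndexInt"))).head().getAs[Int](0) + 1

  // Collect each movie's genre indexes into one array and attach the vector length as a constant column
  val processedSamples = genreIndexSamples
    .groupBy(col("movieId")).agg(collect_list("genreIndexInt").as("genreIndexes"))
    .withColumn("indexSize", typedLit(indexSize))

  val finalSample = processedSamples.withColumn("vector", array2vec(col("genreIndexes"), col("indexSize")))
  finalSample.printSchema()
  finalSample.show(10, false)
}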
Notes (topics worth reviewing):
Using StringIndexer
lit and typedLit
collect_list
Using agg
Spark aggregate functions
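A small sketch of how the functions listed above fit together, run against a local SparkSession. The toy rows, the `tags` column, and the names `NotesSketch`, `maxIndex` and `groupGenres` are hypothetical; only the Spark calls themselves (`agg`, `max`, `collect_list`, `lit`, `typedLit`) are the ones the notes refer to.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, collect_list, lit, max, typedLit}

object NotesSketch {
  // agg without groupBy aggregates the whole DataFrame into a single row
  def maxIndex(df: DataFrame): Int =
    df.agg(max(col("genreIndexInt"))).head().getAs[Int](0)

  def groupGenres(df: DataFrame): DataFrame =
    df.groupBy(col("movieId"))
      // collect_list gathers each group's values into an array column
      .agg(collect_list("genreIndexInt").as("genreIndexes"))
      // lit is enough for a simple constant such as an Int
      .withColumn("indexSize", lit(maxIndex(df) + 1))
      // typedLit also handles Scala collections such as Seq and Map
      .withColumn("tags", typedLit(Seq("a", "b")))

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("notes-sketch").getOrCreate()
    import spark.implicits._

    // Hypothetical toy data: (movieId, genreIndexInt)
    val df = Seq(("1", 6), ("1", 15), ("2", 6), ("3", 1)).toDF("movieId", "genreIndexInt")
    groupGenres(df).show(false)
    spark.stop()
  }
}
```

For a plain `Int` such as `indexSize`, `lit` and `typedLit` behave the same; `typedLit` matters once the constant is a parameterized Scala type.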
Sample output:
one-hot
Raw Movie Samples:
root
|-- movieId: string (nullable = true)
|-- title: string (nullable = true)
|-- genres: string (nullable = true)
+-------+--------------------+--------------------+
|movieId|               title|              genres|
+-------+--------------------+--------------------+
|      1|    Toy Story (1995)|Adventure|Animati...|
|      2|      Jumanji (1995)|Adventure|Childre...|
|      3|Grumpier Old Men ...|      Comedy|Romance|
|      4|Waiting to Exhale...|Comedy|Drama|Romance|
|      5|Father of the Bri...|              Comedy|
|      6|         Heat (1995)|Action|Crime|Thri...|
|      7|      Sabrina (1995)|      Comedy|Romance|
|      8| Tom and Huck (1995)|  Adventure|Children|
|      9| Sudden Death (1995)|              Action|
|     10|    GoldenEye (1995)|Action|Adventure|...|
+-------+--------------------+--------------------+
only showing top 10 rows
OneHotEncoder Example:
root
|-- movieId: string (nullable = true)
|-- title: string (nullable = true)
|-- genres: string (nullable = true)
|-- movieIdNumber: integer (nullable = true)
|-- movieIdVector: vector (nullable = true)
+-------+--------------------+--------------------+-------------+-----------------+
|movieId|               title|              genres|movieIdNumber|    movieIdVector|
+-------+--------------------+--------------------+-------------+-----------------+
|      1|    Toy Story (1995)|Adventure|Animati...|            1| (1001,[1],[1.0])|
|      2|      Jumanji (1995)|Adventure|Childre...|            2| (1001,[2],[1.0])|
|      3|Grumpier Old Men ...|      Comedy|Romance|            3| (1001,[3],[1.0])|
|      4|Waiting to Exhale...|Comedy|Drama|Romance|            4| (1001,[4],[1.0])|
|      5|Father of the Bri...|              Comedy|            5| (1001,[5],[1.0])|
|      6|         Heat (1995)|Action|Crime|Thri...|            6| (1001,[6],[1.0])|
|      7|      Sabrina (1995)|      Comedy|Romance|            7| (1001,[7],[1.0])|
|      8| Tom and Huck (1995)|  Adventure|Children|            8| (1001,[8],[1.0])|
|      9| Sudden Death (1995)|              Action|            9| (1001,[9],[1.0])|
|     10|    GoldenEye (1995)|Action|Adventure|...|           10|(1001,[10],[1.0])|
+-------+--------------------+--------------------+-------------+-----------------+
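The sparse notation `(1001,[1],[1.0])` in the table above reads as (vector size, non-zero indices, values). The plain-Scala sketch below does not use Spark; it only mimics that layout to show where a size like 1001 comes from (largest category index plus one) and what `setDropLast` changes. `SparseVec` and `oneHot` are hypothetical names introduced for illustration.

```scala
object OneHotSketch {
  // Hypothetical stand-in for Spark's sparse vector: (size, indices, values)
  final case class SparseVec(size: Int, indices: Seq[Int], values: Seq[Double]) {
    override def toString: String =
      s"($size,[${indices.mkString(",")}],[${values.mkString(",")}])"
  }

  // numCategories = largest category index + 1
  def oneHot(index: Int, numCategories: Int, dropLast: Boolean): SparseVec = {
    require(index >= 0 && index < numCategories, s"index $index out of range")
    if (dropLast) {
      // The last category gets no slot of its own and is encoded as all zeros
      val size = numCategories - 1
      if (index == size) SparseVec(size, Seq.empty, Seq.empty)
      else SparseVec(size, Seq(index), Seq(1.0))
    } else {
      SparseVec(numCategories, Seq(index), Seq(1.0))
    }
  }

  def main(args: Array[String]): Unit = {
    println(oneHot(10, 1001, dropLast = false))  // (1001,[10],[1.0])
    println(oneHot(1000, 1001, dropLast = true)) // (1000,[],[])
  }
}
```

With `dropLast = false`, as in the example function, every category keeps its own slot, which is why movieId 10 lands at index 10 of a 1001-slot vector.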
multi-hot
MultiHotEncoder Example:
genreIndexSamples:
root
|-- movieId: string (nullable = true)
|-- title: string (nullable = true)
|-- genre: string (nullable = true)
|-- genreIndex: double (nullable = false)
|-- genreIndexInt: integer (nullable = true)
+-------+-----------------------+---------+----------+-------------+
|movieId|title                  |genre    |genreIndex|genreIndexInt|
+-------+-----------------------+---------+----------+-------------+
|1      |Toy Story (1995)       |Adventure|6.0       |6            |
|1      |Toy Story (1995)       |Animation|15.0      |15           |
|1      |Toy Story (1995)       |Children |7.0       |7            |
|1      |Toy Story (1995)       |Comedy   |1.0       |1            |
|1      |Toy Story (1995)       |Fantasy  |10.0      |10           |
|2      |Jumanji (1995)         |Adventure|6.0       |6            |
|2      |Jumanji (1995)         |Children |7.0       |7            |
|2      |Jumanji (1995)         |Fantasy  |10.0      |10           |
|3      |Grumpier Old Men (1995)|Comedy   |1.0       |1            |
|3      |Grumpier Old Men (1995)|Romance  |2.0       |2            |
+-------+-----------------------+---------+----------+-------------+
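The index values in the table above (Comedy 1, Romance 2, Adventure 6, and so on) follow from StringIndexer's default ordering, which assigns indexes by descending label frequency so that the most frequent label gets 0. A plain-Scala sketch of that rule, without Spark; `indexLabels` and the toy genre list are hypothetical, and breaking ties alphabetically is an assumption of this sketch.

```scala
object StringIndexerSketch {
  // Assign 0 to the most frequent label, 1 to the next, and so on.
  // Ties are broken alphabetically here (an assumption of this sketch).
  def indexLabels(labels: Seq[String]): Map[String, Int] =
    labels
      .groupBy(identity)
      .map { case (label, occurrences) => (label, occurrences.size) }
      .toSeq
      .sortBy { case (label, count) => (-count, label) }
      .map { case (label, _) => label }
      .zipWithIndex
      .toMap

  def main(args: Array[String]): Unit = {
    // Hypothetical data: Drama x3, Comedy x2, Romance x1
    val genres = Seq("Drama", "Comedy", "Drama", "Romance", "Comedy", "Drama")
    val index = indexLabels(genres)
    println(index("Drama"))   // 0
    println(index("Comedy"))  // 1
    println(index("Romance")) // 2
  }
}
```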
genreIndexSamples.agg:
+------------------+
|max(genreIndexInt)|
+------------------+
|18                |
+------------------+
finalSample:
root
|-- movieId: string (nullable = true)
|-- genreIndexes: array (nullable = true)
| |-- element: integer (containsNull = true)
|-- indexSize: integer (nullable = false)
|-- vector: vector (nullable = true)
+-------+------------+---------+--------------------------------+
|movieId|genreIndexes|indexSize|vector                          |
+-------+------------+---------+--------------------------------+
|296    |[1, 5, 0, 3]|19       |(19,[0,1,3,5],[1.0,1.0,1.0,1.0])|
|467    |[1]         |19       |(19,[1],[1.0])                  |
|675    |[4, 0, 3]   |19       |(19,[0,3,4],[1.0,1.0,1.0])      |
|691    |[1, 2]      |19       |(19,[1,2],[1.0,1.0])            |
|829    |[1, 10, 14] |19       |(19,[1,10,14],[1.0,1.0,1.0])    |
|125    |[1]         |19       |(19,[1],[1.0])                  |
|451    |[0, 8, 2]   |19       |(19,[0,2,8],[1.0,1.0,1.0])      |
|800    |[0, 8, 16]  |19       |(19,[0,8,16],[1.0,1.0,1.0])     |
|853    |[0]         |19       |(19,[0],[1.0])                  |
|944    |[0]         |19       |(19,[0],[1.0])                  |
+-------+------------+---------+--------------------------------+
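Each `vector` cell above is the sorted form of `genreIndexes` with a 1.0 per index, in a vector of length `indexSize`. The plain-Scala sketch below reproduces the two arrays that `array2vec` hands to `Vectors.sparse`, and also expands them to a dense array so the multi-hot pattern is visible. `sparseParts` and `multiHotDense` are hypothetical helpers, not part of Spark.

```scala
object MultiHotSketch {
  // Sorted indices plus a 1.0 per index: the two arrays array2vec passes to Vectors.sparse
  def sparseParts(indexes: Seq[Int]): (Array[Int], Array[Double]) = {
    val sorted = indexes.sorted.toArray
    (sorted, Array.fill(sorted.length)(1.0))
  }

  // Dense equivalent: 1.0 at every listed position, 0.0 elsewhere
  def multiHotDense(indexes: Seq[Int], length: Int): Array[Double] = {
    val dense = Array.fill(length)(0.0)
    indexes.foreach { i =>
      require(i >= 0 && i < length, s"index $i out of range for length $length")
      dense(i) = 1.0
    }
    dense
  }

  def main(args: Array[String]): Unit = {
    // Movie 296 from the sample output: genreIndexes [1, 5, 0, 3], indexSize 19
    val (indices, values) = sparseParts(Seq(1, 5, 0, 3))
    println(indices.mkString("[", ",", "]")) // [0,1,3,5]
    println(values.mkString("[", ",", "]"))  // [1.0,1.0,1.0,1.0]
    println(multiHotDense(Seq(1, 5, 0, 3), 19).take(6).mkString(",")) // 1.0,1.0,0.0,1.0,0.0,1.0
  }
}
```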
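Another multi-hot approach (suitable when the number of distinct labels is small)
It relies mainly on getWordsIndexMap to build a word-to-index map, and then maps each keyword set through that map.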
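import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.linalg.Vectors // assumption: the ml (not mllib) Vectors, as used by array2vec
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SparkSession}

// Builds a broadcast map from every distinct word to its position in the sorted vocabulary
def getWordsIndexMap(rdd: RDD[Set[String]], ss: SparkSession): Broadcast[Map[String, Int]] = {
  // Union all sets under a single key; the merged vocabulary is collected to the driver
  val allWords = rdd.map { x => (1, x) }.reduceByKey((x, y) => x ++ y).collect().head._2.toArray.sorted
  val wordsMapbt = ss.sparkContext.broadcast(allWords.zip(0.until(allWords.length)).toMap)
  wordsMapbt
}

// Input rows are (id, keywords, from); output rows are (index, multi-hot vector, from)
def transformVec(rdd: RDD[(String, Set[String], String)], ss: SparkSession, mp: Broadcast[Map[String, Int]]) = {
  import ss.sqlContext.implicits._
  // Give every distinct id a unique Long index
  val indexDF = rdd.map { x => x._1 }.distinct().zipWithUniqueId().toDF("id", "index")
  val outRDD = rdd.toDF("id", "keywords", "from")
    .join(indexDF, "id")
    .select("index", "keywords", "from")
    .rdd
    .map {
      case Row(index: Long, keywords: collection.mutable.WrappedArray[String], from: String) =>
        val len = mp.value.size
        // Sorting the keywords keeps the indexes increasing, because the map was built from sorted words
        val arr1 = keywords.toArray.sorted.map {
          x =>
            mp.value(x)
        }
        val arr2 = arr1.map { x => 1.0 }
        (index, Vectors.sparse(len, arr1, arr2), from)
    }
  (indexDF, outRDD)
}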
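The same idea on local collections, without RDDs or broadcast variables, may make the mapping easier to follow: build a sorted vocabulary, assign each word its position, then turn each keyword set into increasing indexes. `buildIndexMap` and `toIndexes` are hypothetical names for this sketch, and the keyword sets are made-up data.

```scala
object VocabularyMultiHotSketch {
  // Union all keyword sets, sort, and map each word to its position
  def buildIndexMap(keywordSets: Seq[Set[String]]): Map[String, Int] =
    keywordSets.flatten.distinct.sorted.zipWithIndex.toMap

  // Indexes of the keywords in increasing order, ready for a sparse vector.
  // A keyword missing from the map throws NoSuchElementException.
  def toIndexes(keywords: Set[String], indexMap: Map[String, Int]): Array[Int] =
    keywords.toArray.map(indexMap).sorted

  def main(args: Array[String]): Unit = {
    val docs = Seq(Set("spark", "scala"), Set("scala", "hadoop"), Set("flink"))
    val indexMap = buildIndexMap(docs)
    // Sorted vocabulary: flink -> 0, hadoop -> 1, scala -> 2, spark -> 3
    println(indexMap.size)                                // 4
    println(toIndexes(docs.head, indexMap).mkString(",")) // 2,3
  }
}
```

Because the whole vocabulary ends up in one in-memory map, this only scales while the number of distinct labels stays small, which matches the caveat in the heading.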