package org.apache.spark.ml.feature
import scala.collection.mutable.ArrayBuilder
import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
/**
 * A feature transformer that merges multiple columns into a single vector column.
 *
 * Input columns may be numeric, boolean or vector typed. Their values are concatenated, in the
 * order given by `inputCols`, into one vector that is stored in `outputCol`.
 */
@Since("1.4.0")
class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
  extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable {

  /** Creates an assembler whose uid is randomly generated with the "vecAssembler" prefix. */
  @Since("1.4.0")
  def this() = this(Identifiable.randomUID("vecAssembler"))

  /** @group setParam */
  @Since("1.4.0")
  def setInputCols(value: Array[String]): this.type = set(inputCols, value)

  /** @group setParam */
  @Since("1.4.0")
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
   * Appends the assembled vector column to `dataset`.
   *
   * The schema is validated first. ML attribute metadata for the output column is then derived
   * from the input columns, and finally the input values are combined row by row.
   *
   * @param dataset the input data; every column named in `inputCols` must be present
   * @return all original columns plus the output vector column carrying the derived metadata
   */
  @Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)

    // ---- Schema transformation ----
    val inputSchema = dataset.schema
    // Only forced when a vector column carries no size information in its metadata.
    lazy val firstRow = dataset.toDF.first()

    // Plain numeric attribute, used whenever a column has no usable ML attribute of its own.
    def defaultNumeric(attrName: String): Attribute =
      NumericAttribute.defaultAttr.withName(attrName)

    // Attributes contributed by a vector column; every name is prefixed with the column name.
    def vectorAttributes(colName: String, field: StructField): Seq[Attribute] = {
      val group = AttributeGroup.fromStructField(field)
      group.attributes match {
        case Some(declared) =>
          // Attributes are defined: copy them with updated names, falling back to the
          // position inside the vector when an attribute is unnamed.
          // TODO: Define a rigorous naming scheme.
          declared.zipWithIndex.map { case (attr, i) =>
            attr.withName(colName + "_" + attr.name.getOrElse(i.toString))
          }.toSeq
        case None =>
          // Treat all entries as numeric. If the metadata does not give the number of
          // attributes, take the vector size from the first row.
          val size = group.numAttributes.getOrElse(
            firstRow.getAs[Vector](inputSchema.fieldIndex(colName)).size)
          Seq.tabulate(size)(i => defaultNumeric(colName + "_" + i))
      }
    }

    val attrs: Array[Attribute] = $(inputCols).flatMap { colName =>
      val field = inputSchema(colName)
      field.dataType match {
        case DoubleType =>
          // If the input column doesn't have an ML attribute, assume numeric.
          val declared = Attribute.fromStructField(field)
          val base: Attribute =
            if (declared == UnresolvedAttribute) NumericAttribute.defaultAttr else declared
          Seq(base.withName(colName))
        case _: NumericType | BooleanType =>
          // Compatible scalar types are assumed to be numeric.
          Seq(defaultNumeric(colName))
        case _: VectorUDT =>
          vectorAttributes(colName, field)
        case otherType =>
          throw new SparkException(s"VectorAssembler does not support the $otherType type")
      }
    }
    val outputMetadata = new AttributeGroup($(outputCol), attrs).toMetadata()

    // ---- Data transformation ----
    val assembleUdf = udf { (row: Row) => VectorAssembler.assemble(row.toSeq: _*) }
    val inputColumns = $(inputCols).map { colName =>
      inputSchema(colName).dataType match {
        // Doubles and vectors are consumed as they are.
        case DoubleType | _: VectorUDT =>
          dataset(colName)
        // Other scalars are cast to double under a uid-qualified alias.
        case _: NumericType | BooleanType =>
          dataset(colName).cast(DoubleType).as(s"${colName}_double_$uid")
      }
    }
    val assembled = assembleUdf(struct(inputColumns: _*)).as($(outputCol), outputMetadata)
    dataset.select(col("*"), assembled)
  }

  /**
   * Validates the input columns and returns `schema` with the output vector column appended.
   *
   * Throws `IllegalArgumentException` when an input column has an unsupported data type or when
   * the output column name is already present in `schema`.
   */
  @Since("1.4.0")
  override def transformSchema(schema: StructType): StructType = {
    val inNames = $(inputCols)
    val outName = $(outputCol)
    // Look up every input column before any type is checked.
    val inTypes = inNames.map(name => schema(name).dataType)
    inTypes.foreach {
      case _: NumericType | BooleanType | _: VectorUDT => // supported, nothing to do
      case other =>
        throw new IllegalArgumentException(s"Data type $other is not supported.")
    }
    if (schema.fieldNames.contains(outName)) {
      throw new IllegalArgumentException(s"Output column $outName already exists.")
    }
    val outputField = StructField(outName, new VectorUDT, nullable = true)
    StructType(schema.fields :+ outputField)
  }

  /** Returns a copy of this assembler created by `defaultCopy` with `extra` applied. */
  @Since("1.4.1")
  override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)
}
@Since("1.6.0")
object VectorAssembler extends DefaultParamsReadable[VectorAssembler] {

  /** Loads a `VectorAssembler` from `path` by delegating to the inherited `load`. */
  @Since("1.6.0")
  override def load(path: String): VectorAssembler = super.load(path)

  /**
   * Concatenates scalar doubles and vectors, in argument order, into a single vector.
   *
   * Zero entries are not stored explicitly. The result is the `compressed` form of a sparse
   * vector whose size is the number of doubles plus the sizes of all vectors in `vv`.
   *
   * @param vv values to assemble; each must be a `Double` or a `Vector`
   * @return the assembled vector
   * @throws SparkException if a value is null or has any other type
   */
  private[feature] def assemble(vv: Any*): Vector = {
    val activeIndices = ArrayBuilder.make[Int]
    val activeValues = ArrayBuilder.make[Double]

    // Stores one output entry unless its value is exactly zero.
    def record(position: Int, value: Double): Unit = {
      if (value != 0.0) {
        activeIndices += position
        activeValues += value
      }
    }

    // Position in the output vector where the next input value starts.
    var offset = 0
    val inputs = vv.iterator
    while (inputs.hasNext) {
      inputs.next() match {
        case scalar: Double =>
          record(offset, scalar)
          offset += 1
        case vec: Vector =>
          vec.foreachActive { (i, v) => record(offset + i, v) }
          offset += vec.size
        case null =>
          // TODO: output Double.NaN?
          throw new SparkException("Values to assemble cannot be null.")
        case other =>
          throw new SparkException(s"$other of type ${other.getClass.getName} is not supported.")
      }
    }
    Vectors.sparse(offset, activeIndices.result(), activeValues.result()).compressed
  }
}
// 1. spark.ml.feature.VectorAssembler
// (Web-page residue) Latest recommended article published on 2024-01-09 01:18:44