【Spark ML系列】LinearSVC原理源码继承关系分析

【Spark ML系列】LinearSVC原理源码继承关系分析

一、class LinearSVC

class LinearSVC @Since("2.2.0") (
    @Since("2.2.0") override val uid: String)
  extends Classifier[Vector, LinearSVC, LinearSVCModel]
  with LinearSVCParams with DefaultParamsWritable {

1. extends Classifier(通用类)

abstract class Classifier[
    FeaturesType,
    E <: Classifier[FeaturesType, E, M],
    M <: ClassificationModel[FeaturesType, M]]
  extends Predictor[FeaturesType, E, M]
  with ClassifierParams {

Classifier 是一个抽象类,继承自 PredictorClassifierParams

它定义了分类器的基本行为,并提供了一些公共方法,如获取类别数量转换数据集

实现类
在这里插入图片描述

/**
 * 单标签二分类或多分类。
 * 类别被索引为 {0, 1, ..., numClasses - 1}。
 *
 * @tparam FeaturesType  输入特征的类型,例如 `Vector`
 * @tparam E  具体的估计器类型
 * @tparam M  具体的模型类型
 */
abstract class Classifier[
    FeaturesType,
    E <: Classifier[FeaturesType, E, M],
    M <: ClassificationModel[FeaturesType, M]]
  extends Predictor[FeaturesType, E, M] with ClassifierParams {

  /**
   * 获取类别数量。首先查找列元数据中的类别数值,如果缺失,则假设类别被索引为 0,1,...,numClasses-1,
   * 并通过找到最大标签值来计算类别数。
   *
   * 类别验证(确保所有类别都是整数且 >= 0)需要在其他地方处理,比如在 `extractLabeledPoints()` 中。
   *
   * @param dataset       包含列 [[labelCol]] 的数据集
   * @param maxNumClasses 从数据中推断时允许的最大类别数。如果元数据中指定了 numClasses,则忽略 maxNumClasses。
   * @return 类别数量
   * @throws IllegalArgumentException 如果元数据未指定 numClasses,并且实际的 numClasses 超过了 maxNumClasses
   */
  protected def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 100): Int = {
    DatasetUtils.getNumClasses(dataset, $(labelCol), maxNumClasses)
  }

  /** @group setParam */
  def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E]

  // TODO: defaultEvaluator (follow-up PR)
}

1.1 with ClassifierParams(所有分类器)

/**  分类器的参数。 */
private[spark] trait ClassifierParams
  extends PredictorParams 
  with HasRawPredictionCol {
  override protected def validateAndTransformSchema(
      schema: StructType,
      fitting: Boolean,
      featuresDataType: DataType): StructType = {
    val parentSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
    SchemaUtils.appendColumn(parentSchema, $(rawPredictionCol), new VectorUDT)
  }
}

ClassifierParams 是分类器的参数特质,继承自 PredictorParamsHasRawPredictionCol。它重写了 validateAndTransformSchema 方法来验证并转换输入数据集的模式

1.1.1 extends PredictorParams(预测回归和分类的参数)

PredictorParams 是预测器的参数特质,继承自 Params,并包含了标签列、特征列和预测列的参数,用于预测(回归和分类)的参数特质

 /**
 * (private[ml]) 用于预测(回归和分类)的参数特质。
 */
private[ml] trait PredictorParams extends Params
  with HasLabelCol with HasFeaturesCol with HasPredictionCol {

  /**
   * 使用提供的参数映射验证和转换输入模式。
   *
   * @param schema 输入模式
   * @param fitting 是否为拟合过程
   * @param featuresDataType FeaturesType 的 SQL 数据类型。
   *                         例如,对于向量特征使用 `VectorUDT`。
   * @return 输出模式
   */
  protected def validateAndTransformSchema(
      schema: StructType,
      fitting: Boolean,
      featuresDataType: DataType): StructType = {
    // TODO: 支持将 Array[Double] 和 Array[Float] 转换为 Vector,当 FeaturesType = Vector 时
    SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
    if (fitting) {
      SchemaUtils.checkNumericType(schema, $(labelCol))

      this match {
        case p: HasWeightCol =>
          if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
            SchemaUtils.checkNumericType(schema, $(p.weightCol))
          }
        case _ =>
      }
    }
    SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
  }
}
1.1.2 with HasRawPredictionCol

1.2 extends Predictor

abstract class Predictor[
    FeaturesType,
    Learner <: Predictor[FeaturesType, Learner, M],
    M <: PredictionModel[FeaturesType, M]]
  extends Estimator[M]
  with PredictorParams {

Predictor 是一个抽象类,继承自 EstimatorPredictorParams。它定义了预测器的基本行为,并提供了一些公共方法,如设置标签列、特征列预测列、拟合模型等.

实现类
在这里插入图片描述
在这里插入图片描述

/**
 * 预测问题(回归和分类)的抽象类。它接受所有数值类型的标签,并在 `fit()` 中自动将其转换为 DoubleType。
 * 如果该预测器支持权重,则它接受所有数值类型的权重,将在 `fit()` 中自动转换为 DoubleType。
 *
 * @tparam FeaturesType 特征的类型
 *                      例如,对于向量特征使用 `VectorUDT`。
 * @tparam Learner 该类的具体实现。如果您继承了此类型,请使用此类型参数来指定具体类型。
 * @tparam M 该类的具体实现,继承自 [[PredictionModel]]。如果您继承了此类型,请使用此类型参数来指定相应模型的具体类型。
 */
abstract class Predictor[
    FeaturesType,
    Learner <: Predictor[FeaturesType, Learner, M],
    M <: PredictionModel[FeaturesType, M]]
  extends Estimator[M] with PredictorParams {

  /** @group setParam */
  def setLabelCol(value: String): Learner = set(labelCol, value).asInstanceOf[Learner]

  /** @group setParam */
  def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner]

  /** @group setParam */
  def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner]

  override def fit(dataset: Dataset[_]): M = {
    // 处理一些内容,例如模式验证。
    // 开发者只需要实现 train() 方法。
    transformSchema(dataset.schema, logging = true)

    // 将 LabelCol 转换为 DoubleType 并保留元数据。
    val labelMeta = dataset.schema($(labelCol)).metadata
    val labelCasted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta)

    // 将 WeightCol 转换为 DoubleType 并保留元数据。
    val casted = this match {
      case p: HasWeightCol =>
        if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
          val weightMeta = dataset.schema($(p.weightCol)).metadata
          labelCasted.withColumn($(p.weightCol), col($(p.weightCol)).cast(DoubleType), weightMeta)
        } else {
          labelCasted
        }
      case _ => labelCasted
    }

    copyValues(train(casted).setParent(this))
  }

  override def copy(extra: ParamMap): Learner

  /**
   * 使用给定的数据集和参数训练模型。
   * 开发者可以实现此方法来替代 `fit()`,以避免处理模式验证并将参数复制到模型中。
   *
   * @param dataset 训练数据集
   * @return 拟合的模型
   */
  protected def train(dataset: Dataset[_]): M

  /**
   * 返回与 FeaturesType 类型参数对应的 SQL 数据类型。
   *
   * 这用于 `validateAndTransformSchema()`。
   * 这个解决方案是因为 Scala 和 Java 在 SQL 上有不同的 API。
   *
   * 默认值为 VectorUDT,但如果 FeaturesType 不是向量,则可能会被重写。
   */
  private[ml] def featuresDataType: DataType = new VectorUDT

  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema, fitting = true, featuresDataType)
  }
}
1.2.1 extends Estimator[M]

抽象类 Estimator,用于将模型拟合到数据上。它是 Spark ML 中的一个基本概念,用于表示机器学习算法中的训练过程。

Estimator 类具有以下功能

  • 定义了三个 fit 方法,用于将模型拟合到输入数据上。其中,第一个 fit 方法使用可选参数对单个模型进行拟合,第二个 fit 方法使用提供的参数映射对单个模型进行拟合,第三个 fit 方法用于拟合多个模型。
  • fit 方法根据传入的参数进行模型拟合,并返回拟合后的模型。
  • fit 方法可以重写,以实现特定的算法优化。
  • copy 方法用于复制 Estimator 对象,并在复制对象中添加额外的参数。

Estimator 类为具体的机器学习算法提供了一个统一的接口,使得用户可以方便地使用和扩展。用户可以继承 Estimator 类并实现自己的算法逻辑。

/**
  * 用于将模型拟合到数据的估计器的抽象类。
  */
abstract class Estimator[M <: Model[M]] extends PipelineStage {

  /**
    * 使用可选参数将单个模型拟合到输入数据中。
    *
    * @param dataset 输入数据集
    * @param firstParamPair 第一个参数对,覆盖嵌入参数
    * @param otherParamPairs 其他参数对。这些值会覆盖此估计器的嵌入ParamMap中指定的任何值。
    * @return 拟合的模型
    */
  @Since("2.0.0")
  @varargs
  def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = {
    val map = new ParamMap()
      .put(firstParamPair)
      .put(otherParamPairs: _*)
    fit(dataset, map)
  }

  /**
    * 使用提供的参数映射将单个模型拟合到输入数据中。
    *
    * @param dataset 输入数据集
    * @param paramMap 参数映射。
    *                 这些值会覆盖此估计器的嵌入ParamMap中指定的任何值。
    * @return 拟合的模型
    */
  @Since("2.0.0")
  def fit(dataset: Dataset[_], paramMap: ParamMap): M = {
    copy(paramMap).fit(dataset)
  }

  /**
    * 将模型拟合到输入数据中。
    */
  @Since("2.0.0")
  def fit(dataset: Dataset[_]): M

  /**
    * 使用多个参数集将多个模型拟合到输入数据中。
    * 默认实现在每个参数映射上使用for循环。
    * 子类可以重写此方法以优化多模型训练。
    *
    * @param dataset 输入数据集
    * @param paramMaps 参数映射的数组。
    *                  这些值会覆盖此估计器的嵌入ParamMap中指定的任何值。
    * @return 拟合的模型,与输入参数映射相匹配
    */
  @Since("2.0.0")
  def fit(dataset: Dataset[_], paramMaps: Seq[ParamMap]): Seq[M] = {
    paramMaps.map(fit(dataset, _))
  }

  override def copy(extra: ParamMap): Estimator[M]
}
1.2.2 with PredictorParams
1.2.3 M <: PredictionModel
abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]]
  extends Model[M] 
  with PredictorParams {

PredictionModel 是预测模型的抽象类,继承自 ModelPredictorParams

它提供了一些通用的方法,如设置特征列和预测列获取特征数、进行转换操作等。

/**
 * 预测任务(回归和分类)的模型抽象类。
 *
 * @tparam FeaturesType 特征的类型
 *                      例如,对于向量特征使用 `VectorUDT`。
 * @tparam M 具体实现的 [[PredictionModel]]。如果您继承了此类型,请使用此类型参数来指定相应模型的具体类型。
 */
abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]]
  extends Model[M] with PredictorParams {

  /** @group setParam */
  def setFeaturesCol(value: String): M = set(featuresCol, value).asInstanceOf[M]

  /** @group setParam */
  def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M]

  /** 返回模型训练时使用的特征数。如果未知,则返回 -1。*/
  @Since("1.6.0")
  def numFeatures: Int = -1

  /**
   * 返回与 FeaturesType 类型参数对应的 SQL 数据类型。
   *
   * 这用于 `validateAndTransformSchema()`。
   * 这个解决方案是因为 Scala 和 Java 在 SQL 上有不同的 API。
   *
   * 默认值为 VectorUDT,但如果 FeaturesType 不是向量,则可能会被重写。
   */
  protected def featuresDataType: DataType = new VectorUDT

  override def transformSchema(schema: StructType): StructType = {
    var outputSchema = validateAndTransformSchema(schema, fitting = false, featuresDataType)
    if ($(predictionCol).nonEmpty) {
      outputSchema = SchemaUtils.updateNumeric(outputSchema, $(predictionCol))
    }
    outputSchema
  }

  /**
   * 通过读取 [[featuresCol]],调用 `predict` 方法,并将预测结果存储为新列 [[predictionCol]] 来转换数据集。
   *
   * @param dataset 输入数据集
   * @return 具有类型为 `Double` 的 [[predictionCol]] 的转换后的数据集
   */
  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    if ($(predictionCol).nonEmpty) {
      transformImpl(dataset)
    } else {
      this.logWarning(s"$uid: Predictor.transform() 不执行任何操作,因为未设置输出列。")
      dataset.toDF
    }
  }

  protected def transformImpl(dataset: Dataset[_]): DataFrame = {
    val outputSchema = transformSchema(dataset.schema, logging = true)
    val predictUDF = udf { features: Any =>
      predict(features.asInstanceOf[FeaturesType])
    }
    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))),
      outputSchema($(predictionCol)).metadata)
  }

  /**
   * 针对给定的特征预测标签。
   * 此方法用于实现 `transform()` 并输出 [[predictionCol]]。
   */
  @Since("2.4.0")
  def predict(features: FeaturesType): Double
}

1.3 M <: ClassificationModel

abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]]
  extends PredictionModel[FeaturesType, M] 
  with ClassifierParams {

ClassificationModel 是分类模型的抽象类,继承自 PredictionModelClassifierParams。它提供了一些通用的方法,如设置原始预测列和预测列、获取类别数量、进行转换操作等。

1.3.1 extends PredictionModel
1.3.2 with ClassifierParams

实现类
在这里插入图片描述

源码

/**
* 由 [[Classifier]] 生成的模型。
* 类别被索引为 {0, 1, ..., numClasses - 1}。
*
* @tparam FeaturesType  输入特征的类型,例如 `Vector`
* @tparam M  具体的模型类型
*/
abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]]
 extends PredictionModel[FeaturesType, M] with ClassifierParams {

 /** @group setParam */
 def setRawPredictionCol(value: String): M = set(rawPredictionCol, value).asInstanceOf[M]

 /** 类别数(标签可以取的值的数量)。*/
 def numClasses: Int

 override def transformSchema(schema: StructType): StructType = {
   var outputSchema = super.transformSchema(schema)
   if ($(predictionCol).nonEmpty) {
     outputSchema = SchemaUtils.updateNumValues(schema,
       $(predictionCol), numClasses)
   }
   if ($(rawPredictionCol).nonEmpty) {
     outputSchema = SchemaUtils.updateAttributeGroupSize(outputSchema,
       $(rawPredictionCol), numClasses)
   }
   outputSchema
 }

 /**
  * 通过读取 [[featuresCol]] 并根据参数指定的方式添加新列进行数据集转换:
  *  - 将预测标签作为类型为 `Double` 的 [[predictionCol]]
  *  - 将原始预测值(置信度)作为类型为 `Vector` 的 [[rawPredictionCol]]
  *
  * @param dataset 输入数据集
  * @return 转换后的数据集
  */
 override def transform(dataset: Dataset[_]): DataFrame = {
   val outputSchema = transformSchema(dataset.schema, logging = true)

   // 只输出选定的列。
   // 这里稍微有些复杂,因为它尝试避免重复计算。
   var outputData = dataset
   var numColsOutput = 0
   if (getRawPredictionCol != "") {
     val predictRawUDF = udf { features: Any =>
       predictRaw(features.asInstanceOf[FeaturesType])
     }
     outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol)),
       outputSchema($(rawPredictionCol)).metadata)
     numColsOutput += 1
   }
   if (getPredictionCol != "") {
     val predCol = if (getRawPredictionCol != "") {
       udf(raw2prediction _).apply(col(getRawPredictionCol))
     } else {
       val predictUDF = udf { features: Any =>
         predict(features.asInstanceOf[FeaturesType])
       }
       predictUDF(col(getFeaturesCol))
     }
     outputData = outputData.withColumn(getPredictionCol, predCol,
       outputSchema($(predictionCol)).metadata)
     numColsOutput += 1
   }

   if (numColsOutput == 0) {
     logWarning(s"$uid: ClassificationModel.transform() 未执行任何操作,因为未设置输出列。")
   }
   outputData.toDF
 }

 final override def transformImpl(dataset: Dataset[_]): DataFrame =
   throw new UnsupportedOperationException(s"不支持在 $getClass 中调用 transformImpl 方法")

 /**
  * 针对给定的特征预测标签。
  * 这个方法用于实现 `transform()` 并输出 [[predictionCol]]。
  *
  * 对于分类,默认实现是从 `predictRaw()` 中选择最大值的索引作为预测结果。
  */
 override def predict(features: FeaturesType): Double = {
   raw2prediction(predictRaw(features))
 }

 /**
  * 针对每个可能的类别进行原始预测。
  * "原始" 预测的含义在不同的算法之间可能有所不同,但它直观地给出了对每个可能类别的置信度(较大的值表示更高的置信度)。
  * 此内部方法用于实现 `transform()` 并输出 [[rawPredictionCol]]。
  *
  * @return 向量,其中第 i 个元素是类别 i 的原始预测值。
  *         这些原始预测值可以是任意实数,其中较大的值表示对该类别的更高置信度。
  */
 @Since("3.0.0")
 def predictRaw(features: FeaturesType): Vector

 /**
  * 根据给定的原始预测向量选择预测的标签。
  * 可以重写此方法以支持偏好特定标签的阈值。
  * @return 预测的标签
  */
 protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.argmax

 /**
  * 如果已设置原始预测列和预测列,则此方法返回当前模型,
  * 否则会为它们生成新列,并将它们设置为当前模型的列。
  */
 private[classification] def findSummaryModel():
 (ClassificationModel[FeaturesType, M], String, String) = {
   val model = if ($(rawPredictionCol).isEmpty && $(predictionCol).isEmpty) {
     copy(ParamMap.empty)
       .setRawPredictionCol("rawPrediction_" + java.util.UUID.randomUUID.toString)
       .setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
   } else if ($(rawPredictionCol).isEmpty) {
     copy(ParamMap.empty).setRawPredictionCol("rawPrediction_" +
       java.util.UUID.randomUUID.toString)
   } else if ($(predictionCol).isEmpty) {
     copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
   } else {
     this
   }
   (model, model.getRawPredictionCol, model.getPredictionCol)
 }
}

2.with LinearSVCParams

private[classification] trait LinearSVCParams 
extends ClassifierParams 
with HasRegParam with HasMaxIter with HasFitIntercept with HasTol with HasStandardization with HasWeightCol
  with HasAggregationDepth with HasThreshold with HasMaxBlockSizeInMB 

这段代码定义了Apache Spark中线性支持向量机(Linear Support Vector Classification, LinearSVC)分类器的参数。使用这些参数, 可以在训练调整线性支持向量机模型时对其进行配置。

在这段代码中,我们定义了一个名为LinearSVCParams的特质,它包含了线性支持向量机分类器的一些参数。这些参数包括正则化参数、最大迭代次数、是否拟合截距、收敛阈值、是否进行标准化、权重列等。

特别要注意的是,在LinearSVCParams特质中,我们为二分类预测中的阈值参数添加了注释。对于线性支持向量机(LinearSVC),该阈值应用于原始预测值(rawPrediction),而不是概率。该阈值可以是任意实数,其中正无穷将使所有预测为0.0,负无穷将使所有预测为1.0。默认情况下,阈值为0.0。

此外,我们还设置了一些参数的默认值,例如正则化参数为0.0、最大迭代次数为100、是否拟合截距为true、收敛阈值为1E-6、是否进行标准化为true、聚合深度为2、最大块大小为0.0

2.1 extends ClassifierParams

 /** 线性支持向量机分类器的参数。*/
private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam
  with HasMaxIter with HasFitIntercept with HasTol with HasStandardization with HasWeightCol
  with HasAggregationDepth with HasThreshold with HasMaxBlockSizeInMB {

  /**
   * 二分类预测中的阈值参数。
   * 对于线性支持向量机(LinearSVC),该阈值应用于原始预测值(rawPrediction),而不是概率。
   * 该阈值可以是任何实数,其中正无穷将使所有预测为0.0,
   * 负无穷将使所有预测为1.0。
   * 默认值为0.0。
   *
   * @group param
   */
  final override val threshold: DoubleParam = new DoubleParam(this, "threshold",
    "应用于原始预测值的二分类预测中的阈值")

  setDefault(regParam -> 0.0, maxIter -> 100, fitIntercept -> true, tol -> 1E-6,
    standardization -> true, threshold -> 0.0, aggregationDepth -> 2, maxBlockSizeInMB -> 0.0)
}

二、LinearSVCModel

class LinearSVCModel private[classification] (
    @Since("2.2.0") override val uid: String,
    @Since("2.2.0") val coefficients: Vector,
    @Since("2.2.0") val intercept: Double)
  extends ClassificationModel[Vector, LinearSVCModel]
  with LinearSVCParams 
 with MLWritable 
with HasTrainingSummary[LinearSVCTrainingSummary] {

1. extends ClassificationModel

2. with LinearSVCParams

注意:核心逻辑在此,以上全是陪衬(其他模型共用接口或类)。

以下代码定义了 LinearSVCModel及其相关辅助类和特质。它们实现了线性支持向量机模型的训练、预测和评估等功能,并提供了一些方法来处理模型的保存和加载。

/**
 * 由 [[LinearSVC]] 训练的线性支持向量机模型
 */
@Since("2.2.0")
class LinearSVCModel private[classification] (
    @Since("2.2.0") override val uid: String,
    @Since("2.2.0") val coefficients: Vector,
    @Since("2.2.0") val intercept: Double)
  extends ClassificationModel[Vector, LinearSVCModel]
  with LinearSVCParams with MLWritable with HasTrainingSummary[LinearSVCTrainingSummary] {

  @Since("2.2.0")
  override val numClasses: Int = 2

  @Since("2.2.0")
  override val numFeatures: Int = coefficients.size

  @Since("2.2.0")
  def setThreshold(value: Double): this.type = set(threshold, value)

  private val margin: Vector => Double = (features) => {
    BLAS.dot(features, coefficients) + intercept
  }

  /**
   * 获取训练集上模型的摘要。如果 `hasSummary` 为 false,则抛出异常
   */
  @Since("3.1.0")
  override def summary: LinearSVCTrainingSummary = super.summary

  /**
   * 在测试数据集上评估模型。
   *
   * @param dataset 要在其上评估模型的测试数据集。
   */
  @Since("3.1.0")
  def evaluate(dataset: Dataset[_]): LinearSVCSummary = {
    val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
    // 处理可能缺失或无效的 rawPrediction 或 prediction 列
    val (summaryModel, rawPrediction, predictionColName) = findSummaryModel()
    new LinearSVCSummaryImpl(summaryModel.transform(dataset),
      rawPrediction, predictionColName, $(labelCol), weightColName)
  }

  override def predict(features: Vector): Double = {
    if (margin(features) > $(threshold)) 1.0 else 0.0
  }

  @Since("3.0.0")
  override def predictRaw(features: Vector): Vector = {
    val m = margin(features)
    Vectors.dense(-m, m)
  }

  override protected def raw2prediction(rawPrediction: Vector): Double = {
    if (rawPrediction(1) > $(threshold)) 1.0 else 0.0
  }

  @Since("2.2.0")
  override def copy(extra: ParamMap): LinearSVCModel = {
    copyValues(new LinearSVCModel(uid, coefficients, intercept), extra).setParent(parent)
  }

  @Since("2.2.0")
  override def write: MLWriter = new LinearSVCModel.LinearSVCWriter(this)

  @Since("3.0.0")
  override def toString: String = {
    s"LinearSVCModel: uid=$uid, numClasses=$numClasses, numFeatures=$numFeatures"
  }
}

/**
 * `LinearSVCModel` 的可读取器
 */
@Since("2.2.0")
object LinearSVCModel extends MLReadable[LinearSVCModel] {

  @Since("2.2.0")
  override def read: MLReader[LinearSVCModel] = new LinearSVCReader

  @Since("2.2.0")
  override def load(path: String): LinearSVCModel = super.load(path)

  /** [[LinearSVCModel]] 的 [[MLWriter]] 实例 */
  private[LinearSVCModel]
  class LinearSVCWriter(instance: LinearSVCModel)
    extends MLWriter with Logging {

    private case class Data(coefficients: Vector, intercept: Double)

    override protected def saveImpl(path: String): Unit = {
      // 保存元数据和参数
      DefaultParamsWriter.saveMetadata(instance, path, sc)
      val data = Data(instance.coefficients, instance.intercept)
      val dataPath = new Path(path, "data").toString
      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
    }
  }

  private class LinearSVCReader extends MLReader[LinearSVCModel] {

    /** 加载模型时与元数据进行校验 */
    private val className = classOf[LinearSVCModel].getName

    override def load(path: String): LinearSVCModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
      val dataPath = new Path(path, "data").toString
      val data = sparkSession.read.format("parquet").load(dataPath)
      val Row(coefficients: Vector, intercept: Double) =
        data.select("coefficients", "intercept").head()
      val model = new LinearSVCModel(metadata.uid, coefficients, intercept)
      metadata.getAndSetParams(model)
      model
    }
  }
}

/**
 * 线性支持向量机结果的抽象类
 */
sealed trait LinearSVCSummary extends BinaryClassificationSummary

/**
 * 线性支持向量机训练结果的抽象类
 */
sealed trait LinearSVCTrainingSummary extends LinearSVCSummary with TrainingSummary

/**
 * 给定模型的线性支持向量机结果
 *
 * @param predictions 模型 `transform` 方法输出的 DataFrame
 * @param scoreCol "predictions" 中给出每个实例的原始预测值的字段
 * @param predictionCol "predictions" 中给出数据实例的预测值的字段,类型为 double
 * @param labelCol "predictions" 中给出每个实例的真实标签的字段
 * @param weightCol "predictions" 中给出每个实例的权重的字段
 */
private class LinearSVCSummaryImpl(
    @transient override val predictions: DataFrame,
    override val scoreCol: String,
    override val predictionCol: String,
    override val labelCol: String,
    override val weightCol: String)
  extends LinearSVCSummary

/**
 * 线性支持向量机训练结果
 *
 * @param predictions 模型 `transform` 方法输出的 DataFrame
 * @param scoreCol "predictions" 中给出每个实例的原始预测值的字段
 * @param predictionCol "predictions" 中给出数据实例的预测值的字段,类型为 double
 * @param labelCol "predictions" 中给出每个实例的真实标签的字段
 * @param weightCol "predictions" 中给出每个实例的权重的字段
 * @param objectiveHistory 每次迭代的目标函数(经过缩放的损失 + 正则化项)
 */
private class LinearSVCTrainingSummaryImpl(
    predictions: DataFrame,
    scoreCol: String,
    predictionCol: String,
    labelCol: String,
    weightCol: String,
    override val objectiveHistory: Array[Double])
  extends LinearSVCSummaryImpl(
    predictions, scoreCol, predictionCol, labelCol, weightCol)
    with LinearSVCTrainingSummary

三、LinearSVC原理

线性支持向量机(Linear Support Vector Machine,简称 Linear SVM)是一种经典的二分类算法,它基于支持向量机(SVM)算法并使用线性核函数。

线性SVM的原理如下:

  1. 数据预处理:首先对输入数据进行预处理,包括特征缩放和特征选择等操作。这可以提高算法的性能和收敛速度。

  2. 定义目标变量和特征变量:将待分类的样本数据分为两类,分别标记为正例和负例。同时,确定用于分类的特征变量。

  3. 寻找最佳超平面:线性SVM的目标是在特征空间中找到一个最佳的超平面,将正例和负例分开。这个超平面被称为决策边界。

  4. 定义优化问题:线性SVM的优化问题是通过最大化间隔来找到最佳超平面。间隔指的是从训练样本到超平面的最小距离。最大化间隔可以增加模型的泛化能力。

  5. 解决优化问题:将优化问题转化为凸优化问题,并使用二次规划等方法求解。通过求解这个优化问题,可以得到最佳的超平面参数。

  6. 预测新样本:在训练完成后,可以使用训练得到的超平面对新样本进行分类预测。根据样本点在超平面的位置,判断其属于正例还是负例。

线性SVM的优点包括:

  • 线性SVM在高维空间中表现良好,并且可以处理大规模数据集。
  • 通过最大化间隔,线性SVM能够提高模型的泛化能力,降低过拟合风险。
  • 线性SVM对于异常值和噪声具有较好的鲁棒性。

然而,

线性SVM也存在一些限制:

  • 线性SVM只适用于线性可分的数据集。当数据集不能被一个超平面完全分开时,线性SVM无法很好地工作。
  • 线性SVM对于处理大量特征的数据集可能会出现计算复杂度高的问题。这时需要使用特征选择等方法来减少特征数量。

总的来说,线性SVM是一种强大的二分类算法,尤其适用于线性可分的问题。它通过寻找最佳超平面,实现了高性能和泛化能力的平衡。

四、“LinearSVC” 和 “Linear SVM” 一样吗?

“LinearSVC” 和 “Linear SVM” 是指同一个算法,即线性支持向量机(Linear Support Vector Machine)。

在Scikit-learn库中,“LinearSVC” 是用于实现线性支持向量机分类器的类名。它使用线性核函数,并采用一对多(One-vs-Rest)策略处理多类分类问题。这个类提供了一些参数和方法,用于调整模型的超参数和进行预测。

而 “Linear SVM” 是对线性支持向量机算法的一种常见命名方式。它强调了算法的核心思想和特点:使用线性核函数,在高维空间中找到一个最佳的超平面来分割数据。

因此,可以认为 “LinearSVC” 和 “Linear SVM” 是等价的,都指代了基于线性核函数的支持向量机算法。它们都适用于处理线性可分的分类问题,并具有相似的原理和功能。

  • 34
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

BigDataMLApplication

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值