【Spark ML系列】 LogisticRegression源码分析

Spark ML LogisticRegression LogisticRegressionModel源码分析

1. 源码适用场景

LogisticRegression类是Spark ML中的逻辑回归模型类。它支持多分类和二分类问题,并且可以通过LBFGS/OWLQN进行传统逻辑回归模型的拟合,或者通过LBFGSB进行边界(箱式)约束逻辑回归模型的拟合。

LogisticRegressionModel 类是 Apache Spark ML 中逻辑回归模型的表示。逻辑回归模型是一种常见的分类模型,可以用于二元分类和多类分类问题。该类提供了获取模型系数、截距以及进行预测和评估的功能。

2. 多种主要用法及其代码示例

  • 创建LogisticRegression对象
import org.apache.spark.ml.classification.LogisticRegression

val lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.01)
  • 设置参数并训练模型
import org.apache.spark.ml.classification.LogisticRegression

val lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.01)

val model = lr.fit(trainingData)
  • 创建二元逻辑回归模型并获取系数向量和截距:
import org.apache.spark.ml.classification.LogisticRegressionModel
import org.apache.spark.ml.linalg.{DenseVector, Vectors}

val coefficients = Vectors.dense(0.5, 0.3)
val intercept = 0.1
val model = new LogisticRegressionModel("lrModel", coefficients, intercept)
val coefficientVector: DenseVector = model.coefficients.asInstanceOf[DenseVector]
val interceptValue: Double = model.intercept
  • 在测试数据集上评估模型:
import org.apache.spark.ml.classification.LogisticRegressionModel
import org.apache.spark.ml.evaluation.BinaryLogisticRegressionSummary
import org.apache.spark.sql.Dataset

def evaluateModel(model: LogisticRegressionModel, dataset: Dataset[_]): BinaryLogisticRegressionSummary = {
  val summary = model.evaluate(dataset)
  summary
}

val lrModel: LogisticRegressionModel = ???
val testDataset: Dataset[_] = ???
val summary: BinaryLogisticRegressionSummary = evaluateModel(lrModel, testDataset)

3. 中文源码

LogisticRegression

/**
 * 逻辑回归。支持:
 *  - 多项逻辑回归(softmax回归)。
 *  - 二项逻辑回归。
 *
 * 这个类通过LBFGS/OWLQN支持传统的逻辑回归模型拟合,
 * 或者通过LBFGSB支持边界约束的逻辑回归模型拟合。
 */
@Since("1.2.0")
class LogisticRegression @Since("1.2.0") (
    @Since("1.4.0") override val uid: String)
  extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel]
  with LogisticRegressionParams with DefaultParamsWritable with Logging {

  // 构造函数
  @Since("1.4.0")
  def this() = this(Identifiable.randomUID("logreg"))

  /**
   * 设置正则化参数。
   * 默认为0.0。
   *
   * @group setParam
   */
  @Since("1.2.0")
  def setRegParam(value: Double): this.type = set(regParam, value)
  setDefault(regParam -> 0.0)

  /**
   * 设置ElasticNet混合参数。
   * 当alpha = 0时,惩罚是L2惩罚。
   * 当alpha = 1时,惩罚是L1惩罚。
   * 当alpha在(0,1)之间时,惩罚是L1和L2的组合。
   * 默认为0.0,即L2惩罚。
   *
   * 注意:在边界约束优化下拟合只支持L2正则化,
   * 所以如果该参数的值非零会抛出异常。
   *
   * @group setParam
   */
  @Since("1.4.0")
  def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value)
  setDefault(elasticNetParam -> 0.0)

  /**
   * 设置最大迭代次数。
   * 默认为100。
   *
   * @group setParam
   */
  @Since("1.2.0")
  def setMaxIter(value: Int): this.type = set(maxIter, value)
  setDefault(maxIter -> 100)

  /**
   * 设置迭代的收敛容差。
   * 较小的值会在更多迭代次数下获得更高的精度。
   * 默认为1E-6。
   *
   * @group setParam
   */
  @Since("1.4.0")
  def setTol(value: Double): this.type = set(tol, value)
  setDefault(tol -> 1E-6)

  /**
   * 是否拟合截距项。
   * 默认为true。
   *
   * @group setParam
   */
  @Since("1.4.0")
  def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
  setDefault(fitIntercept -> true)

  /**
   * 设置参数[[family]]的值。
   * 默认为"auto"。
   *
   * @group setParam
   */
  @Since("2.1.0")
  def setFamily(value: String): this.type = set(family, value)
  setDefault(family -> "auto")

  /**
   * 是否在拟合模型之前对训练特征进行标准化。
   * 模型的系数将始终以原始比例返回,
   * 因此对于用户来说是透明的。注意,无论是否标准化,
   * 当没有应用正则化时,模型应该始终收敛到相同的解。
   * 在R的GLMNET软件包中,默认行为也是true。
   * 默认为true。
   *
   * @group setParam
   */
  @Since("1.5.0")
  def setStandardization(value: Boolean): this.type = set(standardization, value)
  setDefault(standardization -> true)

  @Since("1.5.0")
  override def setThreshold(value: Double): this.type = super.setThreshold(value)
  setDefault(threshold -> 0.5)

  @Since("1.5.0")
  override def getThreshold: Double = super.getThreshold

  /**
   * 设置参数[[weightCol]]的值。
   * 如果未设置或为空,则将所有实例权重视为1.0。
   * 默认未设置,因此所有实例的权重都为1.0。
   *
   * @group setParam
   */
  @Since("1.6.0")
  def setWeightCol(value: String): this.type = set(weightCol, value)

  @Since("1.5.0")
  override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)

  @Since("1.5.0")
  override def getThresholds: Array[Double] = super.getThresholds

  /**
   * 对于treeAggregate推荐的深度(大于等于2)。
   * 如果特征的维度或分区的数量很大,
   * 可以将此参数调整为更大的值。
   * 默认为2。
   *
   * @group expertSetParam
   */
  @Since("2.1.0")
  def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
  setDefault(aggregationDepth -> 2)

  /**
   * 设置边界约束优化下系数的下界。
   *
   * @group expertSetParam
   */
  @Since("2.2.0")
  def setLowerBoundsOnCoefficients(value: Matrix): this.type = set(lowerBoundsOnCoefficients, value)

  /**
   * 设置边界约束优化下系数的上界。
   *
   * @group expertSetParam
   */
  @Since("2.2.0")
  def setUpperBoundsOnCoefficients(value: Matrix): this.type = set(upperBoundsOnCoefficients, value)

  /**
   * 设置边界约束优化下截距的下界。
   *
   * @group expertSetParam
   */
  @Since("2.2.0")
  def setLowerBoundsOnIntercepts(value: Vector): this.type = set(lowerBoundsOnIntercepts, value)

  /**
   * 设置边界约束优化下截距的上界。
   *
   * @group expertSetParam
   */
  @Since("2.2.0")
  def setUpperBoundsOnIntercepts(value: Vector): this.type = set(upperBoundsOnIntercepts, value)

  private def assertBoundConstrainedOptimizationParamsValid(
      numCoefficientSets: Int,
      numFeatures: Int): Unit = {
    if (isSet(lowerBoundsOnCoefficients)) {
      require($(lowerBoundsOnCoefficients).numRows == numCoefficientSets &&
        $(lowerBoundsOnCoefficients).numCols == numFeatures,
        s"LowerBoundsOnCoefficients的形状必须与二项逻辑回归相兼容,即(1,特征数)" +
          s"或多项逻辑回归为(类别数,特征数),但是找到了:" +
          s"(${getLowerBoundsOnCoefficients.numRows}, ${getLowerBoundsOnCoefficients.numCols})。")
    }
    if (isSet(upperBoundsOnCoefficients)) {
      require($(upperBoundsOnCoefficients).numRows == numCoefficientSets &&
        $(upperBoundsOnCoefficients).numCols == numFeatures,
        s"UpperBoundsOnCoefficients的形状必须与二项逻辑回归相兼容,即(1,特征数)" +
          s"或多项逻辑回归为(类别数,特征数),但是找到了:" +
          s"(${getUpperBoundsOnCoefficients.numRows}, ${getUpperBoundsOnCoefficients.numCols})。")
    }
    if (isSet(lowerBoundsOnIntercepts)) {
      require($(lowerBoundsOnIntercepts).size == numCoefficientSets, "LowerBoundsOnIntercepts的大小" +
        "必须等于1(二项逻辑回归)或多项逻辑回归的类别数,但是找到了:" +
        s"${getLowerBoundsOnIntercepts.size}。")
    }
    if (isSet(upperBoundsOnIntercepts)) {
      require($(upperBoundsOnIntercepts).size == numCoefficientSets, "UpperBoundsOnIntercepts的大小" +
        "必须等于1(二项逻辑回归)或多项逻辑回归的类别数,但是找到了:" +
        s"${getUpperBoundsOnIntercepts.size}。")
    }
    if (isSet(lowerBoundsOnCoefficients) && isSet(upperBoundsOnCoefficients)) {
      require($(lowerBoundsOnCoefficients).toArray.zip($(upperBoundsOnCoefficients).toArray)
        .forall(x => x._1 <= x._2), "LowerBoundsOnCoefficients应该始终小于等于UpperBoundsOnCoefficients," +
        s"但是找到了:lowerBoundsOnCoefficients = $getLowerBoundsOnCoefficients," +
        s"upperBoundsOnCoefficients = $getUpperBoundsOnCoefficients。")
    }
    if (isSet(lowerBoundsOnIntercepts) && isSet(upperBoundsOnIntercepts)) {
      require($(lowerBoundsOnIntercepts).toArray.zip($(upperBoundsOnIntercepts).toArray)
        .forall(x => x._1 <= x._2), "LowerBoundsOnIntercepts应该始终小于等于UpperBoundsOnIntercepts," +
        s"但是找到了:lowerBoundsOnIntercepts = $getLowerBoundsOnIntercepts," +
        s"upperBoundsOnIntercepts = $getUpperBoundsOnIntercepts。")
    }
  }

  private var optInitialModel: Option[LogisticRegressionModel] = None

  private[spark] def setInitialModel(model: LogisticRegressionModel): this.type = {
    this.optInitialModel = Some(model)
    this
  }

  override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = {
    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
    train(dataset, handlePersistence)
  }

  protected[spark] def train(
      dataset: Dataset[_],
      handlePersistence: Boolean): LogisticRegressionModel = instrumented { instr =>
    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
    val instances: RDD[Instance] =
      dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
        case Row(label: Double, weight: Double, features: Vector) =>
          Instance(label, weight, features)
      }

    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)

    instr.logPipelineStage(this)
    instr.logDataset(dataset)
    instr.logParams(this, regParam, elasticNetParam, standardization, threshold,
      maxIter, tol, fitIntercept)

    val (summarizer, labelSummarizer) = {
      val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer),
        instance: Instance) =>
          (c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight))

      val combOp = (c1: (MultivariateOnlineSummarizer, MultiClassSummarizer),
        c2: (MultivariateOnlineSummarizer, MultiClassSummarizer)) =>
          (c1._1.merge(c2._1), c1._2.merge(c2._2))

      instances.treeAggregate(
        (new MultivariateOnlineSummarizer, new MultiClassSummarizer)
      )(seqOp, combOp, $(aggregationDepth))
    }
    instr.logNumExamples(summarizer.count)
    instr.logNamedValue("lowestLabelWeight", labelSummarizer.histogram.min.toString)
    instr.logNamedValue("highestLabelWeight", labelSummarizer.histogram.max.toString)

    val histogram = labelSummarizer.histogram
    val numInvalid = labelSummarizer.countInvalid
    val numFeatures = summarizer.mean.size
    val numFeaturesPlusIntercept = if (getFitIntercept) numFeatures + 1 else numFeatures

    val numClasses = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
      case Some(n: Int) =>
        require(n >= histogram.length, s"指定的类别数$n小于唯一标签数${histogram.length}。")
        n
      case None => histogram.length
    }

    val isMultinomial = getFamily.toLowerCase(Locale.ROOT) match {
  case "binomial" =>
    require(numClasses == 1 || numClasses == 2, s"Binomial family only supports 1 or 2 " +
      s"outcome classes but found $numClasses.")
    false
  case "multinomial" => true
  case "auto" => numClasses > 2
  case other => throw new IllegalArgumentException(s"Unsupported family: $other")
}
val numCoefficientSets = if (isMultinomial) numClasses else 1

// 如果使用边界约束优化,则检查参数交互是否有效。
if (usingBoundConstrainedOptimization) {
  assertBoundConstrainedOptimizationParamsValid(numCoefficientSets, numFeatures)
}

// 如果定义了阈值参数,则检查阈值参数的长度是否与类别数一致。
if (isDefined(thresholds)) {
  require($(thresholds).length == numClasses, this.getClass.getSimpleName +
    ".train() called with non-matching numClasses and thresholds.length." +
    s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}

// 记录日志:类别数和特征数
instr.logNumClasses(numClasses)
instr.logNumFeatures(numFeatures)

val (coefficientMatrix, interceptVector, objectiveHistory) = {
  // 如果存在无效标签,则抛出异常
  if (numInvalid != 0) {
    val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " +
      s"Found $numInvalid invalid labels."
    instr.logError(msg)
    throw new SparkException(msg)
  }

  // 检查是否所有的标签都是相同的值,并且fitIntercept=true
  val isConstantLabel = histogram.count(_ != 0.0) == 1

  if ($(fitIntercept) && isConstantLabel && !usingBoundConstrainedOptimization) {
    // 如果所有标签都是相同的值,并且fitIntercept=true,那么系数将全部为零,不需要训练。
    instr.logWarning(s"All labels are the same value and fitIntercept=true, so the " +
      s"coefficients will be zeros. Training is not needed.")
    val constantLabelIndex = Vectors.dense(histogram).argmax
    val coefMatrix = new SparseMatrix(numCoefficientSets, numFeatures,
      new Array[Int](numCoefficientSets + 1), Array.empty[Int], Array.empty[Double],
      isTransposed = true).compressed
    val interceptVec = if (isMultinomial) {
      Vectors.sparse(numClasses, Seq((constantLabelIndex, Double.PositiveInfinity)))
    } else {
      Vectors.dense(if (numClasses == 2) Double.PositiveInfinity else Double.NegativeInfinity)
    }
    (coefMatrix, interceptVec, Array.empty[Double])
  } else {
    // 如果不满足上述条件,则进行正常的逻辑回归模型训练流程

    // 如果不使用截距项并且所有标签属于单一类别,给出警告
    if (!$(fitIntercept) && isConstantLabel) {
      instr.logWarning(s"All labels belong to a single class and fitIntercept=false. It's a " +
        s"dangerous ground, so the algorithm may not converge.")
    }

    // 计算特征的均值和标准差
    val featuresMean = summarizer.mean.toArray
    val featuresStd = summarizer.variance.toArray.map(math.sqrt)

    // 如果不使用截距项并且数据集中存在非零常数列,给出警告
    if (!$(fitIntercept) && (0 until numFeatures).exists { i =>
      featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) {
      instr.logWarning("Fitting LogisticRegressionModel without intercept on dataset with " +
        "constant nonzero column, Spark MLlib outputs zero coefficients for constant " +
        "nonzero columns. This behavior is the same as R glmnet but different from LIBSVM.")
    }

    // 计算正则化的L1项和L2项系数
    val regParamL1 = $(elasticNetParam) * $(regParam)
    val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)

    // 广播特征标准差
    val bcFeaturesStd = instances.context.broadcast(featuresStd)

    // 构造逻辑回归聚合器函数
    val getAggregatorFunc = new LogisticAggregator(bcFeaturesStd, numClasses, $(fitIntercept),
      multinomial = isMultinomial)(_)
    
    // 获取特征的标准差
    val getFeaturesStd = (j: Int) => if (j >= 0 && j < numCoefficientSets * numFeatures) {
      featuresStd(j / numCoefficientSets)
    } else {
      0.0
    }

    // 构建正则化方法
    val regularization = if (regParamL2 != 0.0) {
      val shouldApply = (idx: Int) => idx >= 0 && idx < numFeatures * numCoefficientSets
      Some(new L2Regularization(regParamL2, shouldApply,
        if ($(standardization)) None else Some(getFeaturesStd)))
    } else {
      None
    }

    // 构建损失函数
    val costFun = new RDDLossFunction(instances, getAggregatorFunc, regularization,
      $(aggregationDepth))

    // 计算系数和截距的总数量
    val numCoeffsPlusIntercepts = numFeaturesPlusIntercept * numCoefficientSets

    // 如果使用边界约束优化,则计算下界和上界
    val (lowerBounds, upperBounds): (Array[Double], Array[Double]) = {
      if (usingBoundConstrainedOptimization) {
        val lowerBounds = Array.fill[Double](numCoeffsPlusIntercepts)(Double.NegativeInfinity)
        val upperBounds = Array.fill[Double](numCoeffsPlusIntercepts)(Double.PositiveInfinity)
        val isSetLowerBoundsOnCoefficients = isSet(lowerBoundsOnCoefficients)
        val isSetUpperBoundsOnCoefficients = isSet(upperBoundsOnCoefficients)
        val isSetLowerBoundsOnIntercepts = isSet(lowerBoundsOnIntercepts)
        val isSetUpperBoundsOnIntercepts = isSet(upperBoundsOnIntercepts)

        var i = 0
        while (i < numCoeffsPlusIntercepts) {
          val coefficientSetIndex = i % numCoefficientSets
          val featureIndex = i / numCoefficientSets
          if (featureIndex < numFeatures) {
            if (isSetLowerBoundsOnCoefficients) {
              lowerBounds(i) = $(lowerBoundsOnCoefficients)(
                coefficientSetIndex, featureIndex) * featuresStd(featureIndex)
            }
            if (isSetUpperBoundsOnCoefficients) {
              upperBounds(i) = $(upperBoundsOnCoefficients)(
                coefficientSetIndex, featureIndex) * featuresStd(featureIndex)
            }
          } else {
            if (isSetLowerBoundsOnIntercepts) {
              lowerBounds(i) = $(lowerBoundsOnIntercepts)(coefficientSetIndex)
            }
            if (isSetUpperBoundsOnIntercepts) {
              upperBounds(i) = $(upperBoundsOnIntercepts)(coefficientSetIndex)
            }
          }
          i += 1
        }
        (lowerBounds, upperBounds)
      } else {
        (null, null)
      }
    }

    // 根据正则化参数和弹性网络参数选择优化器
    val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
      if (lowerBounds != null && upperBounds != null) {
        new BreezeLBFGSB(
          BDV[Double](lowerBounds), BDV[Double](upperBounds), $(maxIter), 10, $(tol))
      } else {
        new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
      }
    } else {
      val standardizationParam = $(standardization)
      def regParamL1Fun = (index: Int) => {
        // Remove the L1 penalization on the intercept
        val isIntercept = $(fitIntercept) && index >= numFeatures * numCoefficientSets
        if (isIntercept) {
          0.0
        } else {
          if (standardizationParam) {
            regParamL1
          } else {
            val featureIndex = index / numCoefficientSets
            // If `standardization` is false, we still standardize the data
            // to improve the rate of convergence; as a result, we have to
            // perform this reverse standardization by penalizing each component
            // differently to get effectively the same objective function when
            // the training dataset is not standardized.
            if (featuresStd(featureIndex) != 0.0) {
              regParamL1 / featuresStd(featureIndex)
            } else {
              0.0
            }
          }
        }
      }
      new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
    }

    // 初始化系数矩阵
    val initialCoefWithInterceptMatrix =
      Matrices.zeros(numCoefficientSets, numFeaturesPlusIntercept)

    // 判断初始模型是否有效
    val initialModelIsValid = optInitialModel match {
      case Some(_initialModel) =>
        val providedCoefs = _initialModel.coefficientMatrix
        val modelIsValid = (providedCoefs.numRows == numCoefficientSets) &&
          (providedCoefs.numCols == numFeatures) &&
          (_initialModel.interceptVector.size == numCoefficientSets) &&
          (_initialModel.getFitIntercept == $(fitIntercept))
        if (!modelIsValid) {
          instr.logWarning(s"Initial coefficients will be ignored! Its dimensions " +
            s"(${providedCoefs.numRows}, ${providedCoefs.numCols}) did not match the " +
            s"expected size ($numCoefficientSets, $numFeatures)")
        }
        modelIsValid
      case None => false
    }

    // 如果初始模型有效,则使用提供的系数初始化
    if (initialModelIsValid) {
      val providedCoef = optInitialModel.get.coefficientMatrix
      providedCoef.foreachActive { (classIndex, featureIndex, value) =>
        // We need to scale the coefficients since they will be trained in the scaled space
        initialCoefWithInterceptMatrix.update(classIndex, featureIndex,
          value * featuresStd(featureIndex))
      }
      if ($(fitIntercept)) {
        optInitialModel.get.interceptVector.foreachActive { (classIndex, value) =>
          initialCoefWithInterceptMatrix.update(classIndex, numFeatures, value)
        }
      }
    } else if ($(fitIntercept) && isMultinomial) {
      /*
         对于多项式逻辑回归,当系数初始化为零时,如果我们初始化截距项使其遵循标签的分布,收敛会更快。
         根据以下公式计算截距项的初始值:
         P(1) = \exp(b_1) / Z
         ...
         P(K) = \exp(b_K) / Z
         where Z = \sum_{k=1}^{K} \exp(b_k)
         由于这个问题存在多个解,满足上述条件的一种解是
         \exp(b_k) = count_k * \exp(\lambda)
         b_k = \log(count_k) * \lambda
         \lambda 是一个自由参数,所以选择使均值为0的相位 \lambda。得到如下结果:
         b_k = \log(count_k)
         b_k' = b_k - \mean(b_k)
       */
      val rawIntercepts = histogram.map(math.log1p) // add 1 for smoothing (log1p(x) = log(1+x))
      val rawMean = rawIntercepts.sum / rawIntercepts.length
      rawIntercepts.indices.foreach { i =>
        initialCoefWithInterceptMatrix.update(i, numFeatures, rawIntercepts(i) - rawMean)
      }
    } else if ($(fitIntercept)) {
      /*
         对于二元逻辑回归,当系数初始化为零时,如果我们初始化截距项使其遵循标签的分布,收敛会更快。
         根据以下公式计算截距项的初始值:
         P(0) = 1 / (1 + \exp(b)), and
         P(1) = \exp(b) / (1 + \exp(b))
         因此
         b = \log{P(1) / P(0)} = \log{count_1 / count_0}
       */
      initialCoefWithInterceptMatrix.update(0, numFeatures,
        math.log(histogram(1) / histogram(0)))
    }

    // 如果使用边界约束优化,则确保所有初始值位于相应的边界内。
    if (usingBoundConstrainedOptimization) {
      var i = 0
      while (i < numCoeffsPlusIntercepts) {
        val coefficientSetIndex = i % numCoefficientSets
        val featureIndex = i / numCoefficientSets
        if (initialCoefWithInterceptMatrix(coefficientSetIndex, featureIndex) < lowerBounds(i)) {
          initialCoefWithInterceptMatrix.update(
            coefficientSetIndex, featureIndex, lowerBounds(i))
        } else if (initialCoefWithInterceptMatrix(coefficientSetIndex, featureIndex) > upperBounds(i)) {
          initialCoefWithInterceptMatrix.update(
            coefficientSetIndex, featureIndex, upperBounds(i))
        }
        i += 1
      }
    }

    // 使用优化器进行模型训练,得到最终的系数矩阵、截距向量和目标函数历史记录
    val states = optimizer.iterations(new CachedDiffFunction(costFun),
      new BDV[Double](initialCoefWithInterceptMatrix.toArray))

    val arrayBuilder = mutable.ArrayBuilder.make[Double]
    var state: optimizer.State = null
    while (states.hasNext) {
      state = states.next()
      arrayBuilder += state.adjustedValue
    }
    bcFeaturesStd.destroy(blocking = false)

    if (state == null) {
      val msg = s"${optimizer.getClass.getName} failed."
      instr.logError(msg)
      throw new SparkException(msg)
    }

    /*
    系数在缩放空间中进行训练;我们将它们转换回原始空间。

    此外,由于在训练过程中,为了避免额外的计算,系数按列主序排列,我们在将它们传递给模型之前将它们转换回行主序。

    需要注意的是,在缩放空间和原始空间中,截距是相同的;
    因此,不需要进行缩放。
         */
        val allCoefficients = state.x.toArray.clone()
        val allCoefMatrix = new DenseMatrix(numCoefficientSets, numFeaturesPlusIntercept,
          allCoefficients)
        val denseCoefficientMatrix = new DenseMatrix(numCoefficientSets, numFeatures,
          new Array[Double](numCoefficientSets * numFeatures), isTransposed = true)
        val interceptVec = if ($(fitIntercept) || !isMultinomial) {
          Vectors.zeros(numCoefficientSets)
        } else {
          Vectors.sparse(numCoefficientSets, Seq.empty)
        }
        // 从组合矩阵中分离出截距和系数
        allCoefMatrix.foreachActive { (classIndex, featureIndex, value) =>
          val isIntercept = $(fitIntercept) && (featureIndex == numFeatures)
          if (!isIntercept && featuresStd(featureIndex) != 0.0) {
            denseCoefficientMatrix.update(classIndex, featureIndex,
              value / featuresStd(featureIndex))
          }
          if (isIntercept) interceptVec.toArray(classIndex) = value
        }

        if ($(regParam) == 0.0 && isMultinomial && !usingBoundConstrainedOptimization) {
          /*
        当不应用正则化时,多项式系数缺乏可辨识性,
        因为我们没有使用一个基准类。可以向系数添加任何常量值并获得相同的似然度。
        因此,在这里,我们选择均值居中的系数以确保再现性。
        这种方法遵循glmnet中的方法,具体描述如下:

        Friedman等人,“Regularization Paths for Generalized Linear Models via
        Coordinate Descent”,https://core.ac.uk/download/files/153/6287975.pdf
           */
          val centers = Array.fill(numFeatures)(0.0)
          denseCoefficientMatrix.foreachActive { case (i, j, v) =>
            centers(j) += v
          }
          centers.transform(_ / numCoefficientSets)
          denseCoefficientMatrix.foreachActive { case (i, j, v) =>
            denseCoefficientMatrix.update(i, j, v - centers(j))
          }
        }

		// 在使用多项式算法时,对截距进行中心化
        if ($(fitIntercept) && isMultinomial && !usingBoundConstrainedOptimization) {
          val interceptArray = interceptVec.toArray
          val interceptMean = interceptArray.sum / interceptArray.length
          (0 until interceptVec.size).foreach { i => interceptArray(i) -= interceptMean }
        }
        (denseCoefficientMatrix.compressed, interceptVec.compressed, arrayBuilder.result())
      }
    }

    if (handlePersistence) instances.unpersist()

    val model = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector,
      numClasses, isMultinomial))

    val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
    val logRegSummary = if (numClasses <= 2) {
      new BinaryLogisticRegressionTrainingSummaryImpl(
        summaryModel.transform(dataset),
        probabilityColName,
        predictionColName,
        $(labelCol),
        $(featuresCol),
        objectiveHistory)
    } else {
      new LogisticRegressionTrainingSummaryImpl(
        summaryModel.transform(dataset),
        probabilityColName,
        predictionColName,
        $(labelCol),
        $(featuresCol),
        objectiveHistory)
    }
    model.setSummary(Some(logRegSummary))
  }

  @Since("1.4.0")
  override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra)
}

LogisticRegressionModel

@Since("1.4.0")
class LogisticRegressionModel private[spark] (
    @Since("1.4.0") override val uid: String,
    @Since("2.1.0") val coefficientMatrix: Matrix,
    @Since("2.1.0") val interceptVector: Vector,
    @Since("1.3.0") override val numClasses: Int,
    private val isMultinomial: Boolean)
  extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
  with LogisticRegressionParams with MLWritable {

  // 确保系数矩阵的行数与截距向量的大小相等
  require(coefficientMatrix.numRows == interceptVector.size, s"Dimension mismatch! Expected " +
    s"coefficientMatrix.numRows == interceptVector.size, but ${coefficientMatrix.numRows} != " +
    s"${interceptVector.size}")

  // 构造函数重载,用于创建二元逻辑回归模型
  private[spark] def this(uid: String, coefficients: Vector, intercept: Double) =
    this(uid, new DenseMatrix(1, coefficients.size, coefficients.toArray, isTransposed = true),
      Vectors.dense(intercept), 2, isMultinomial = false)

  /**
   * 获取"二元"逻辑回归模型的系数向量。如果该模型是使用"多元"家族训练的,则抛出异常。
   *
   * @return Vector
   */
  @Since("2.0.0")
  def coefficients: Vector = if (isMultinomial) {
    throw new SparkException("Multinomial models contain a matrix of coefficients, use " +
      "coefficientMatrix instead.")
  } else {
    _coefficients
  }

  // 将系数矩阵转换为合适的向量表示,而不复制数据
  private lazy val _coefficients: Vector = {
    require(coefficientMatrix.isTransposed,
      "LogisticRegressionModel coefficients should be row major for binomial model.")
    coefficientMatrix match {
      case dm: DenseMatrix => Vectors.dense(dm.values)
      case sm: SparseMatrix => Vectors.sparse(coefficientMatrix.numCols, sm.rowIndices, sm.values)
    }
  }

  /**
   * 获取"二元"逻辑回归模型的截距。如果该模型是使用"多元"家族训练的,则抛出异常。
   *
   * @return Double
   */
  @Since("1.3.0")
  def intercept: Double = if (isMultinomial) {
    throw new SparkException("Multinomial models contain a vector of intercepts, use " +
      "interceptVector instead.")
  } else {
    _intercept
  }

  private lazy val _intercept = interceptVector.toArray.head

  // 设置阈值
  @Since("1.5.0")
  override def setThreshold(value: Double): this.type = super.setThreshold(value)

  // 获取阈值
  @Since("1.5.0")
  override def getThreshold: Double = super.getThreshold

  // 设置多个阈值
  @Since("1.5.0")
  override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)

  // 获取多个阈值
  @Since("1.5.0")
  override def getThresholds: Array[Double] = super.getThresholds

  /** 对于二元分类,计算类别1的边际值(原始预测)。 */
  private val margin: Vector => Double = (features) => {
    BLAS.dot(features, _coefficients) + _intercept
  }

  /** 计算每个类别标签的边际值(原始预测)。 */
  private val margins: Vector => Vector = (features) => {
    val m = interceptVector.toDense.copy
    BLAS.gemv(1.0, coefficientMatrix, features, 1.0, m)
    m
  }

  /** 对于二元分类,计算类别1的分数(概率)。 */
  private val score: Vector => Double = (features) => {
    val m = margin(features)
    1.0 / (1.0 + math.exp(-m))
  }

  // 特征数量
  @Since("1.6.0")
  override val numFeatures: Int = coefficientMatrix.numCols

  private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None

  /**
   * 获取模型在训练集上的摘要。如果`trainingSummary == None`,则抛出异常。
   */
  @Since("1.5.0")
  def summary: LogisticRegressionTrainingSummary = trainingSummary.getOrElse {
    throw new SparkException("No training summary available for this LogisticRegressionModel")
  }

  /**
   * 获取模型在训练集上的摘要。如果`trainingSummary == None`或该模型为多类模型,则抛出异常。
   */
  @Since("2.3.0")
  def binarySummary: BinaryLogisticRegressionTrainingSummary = summary match {
    case b: BinaryLogisticRegressionTrainingSummary => b
    case _ =>
      throw new RuntimeException("Cannot create a binary summary for a non-binary model" +
        s"(numClasses=${numClasses}), use summary instead.")
  }

  /**
   * 如果概率和预测列被设置了,此方法返回当前模型,
   * 否则它为它们生成新的列,并将它们设置为当前模型的列。
   */
  private[classification] def findSummaryModel():
      (LogisticRegressionModel, String, String) = {
    val model = if ($(probabilityCol).isEmpty && $(predictionCol).isEmpty) {
      copy(ParamMap.empty)
        .setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString)
        .setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
    } else if ($(probabilityCol).isEmpty) {
      copy(ParamMap.empty).setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString)
    } else if ($(predictionCol).isEmpty) {
      copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
    } else {
      this
    }
    (model, model.getProbabilityCol, model.getPredictionCol)
  }

  private[classification]
  def setSummary(summary: Option[LogisticRegressionTrainingSummary]): this.type = {
    this.trainingSummary = summary
    this
  }

  /** 指示该模型实例是否存在训练摘要。 */
  @Since("1.5.0")
  def hasSummary: Boolean = trainingSummary.isDefined

  /**
   * 对测试数据集评估模型。
   *
   * @param dataset 要在其上评估模型的测试数据集。
   */
  @Since("2.0.0")
  def evaluate(dataset: Dataset[_]): LogisticRegressionSummary = {
    // 处理可能缺失或无效的预测列
    val (summaryModel, probabilityColName, predictionColName) = findSummaryModel()
    if (numClasses > 2) {
      new LogisticRegressionSummaryImpl(summaryModel.transform(dataset),
        probabilityColName, predictionColName, $(labelCol), $(featuresCol))
    } else {
      new BinaryLogisticRegressionSummaryImpl(summaryModel.transform(dataset),
        probabilityColName, predictionColName, $(labelCol), $(featuresCol))
    }
  }

  /**
   * 预测给定特征向量的类别标签。
   * 这个行为可以通过`thresholds`来调整。
   */
  override def predict(features: Vector): Double = if (isMultinomial) {
    super.predict(features)
  } else {
    // 注意:我们应该使用getThreshold而不是$(threshold),因为getThreshold已经被覆盖了。
    if (score(features) > getThreshold) 1 else 0
  }

  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
    rawPrediction match {
      case dv: DenseVector =>
        if (isMultinomial) {
          val size = dv.size
          val values = dv.values

          // 获取最大边际值
          val maxMarginIndex = rawPrediction.argmax
          val maxMargin = rawPrediction(maxMarginIndex)

          if (maxMargin == Double.PositiveInfinity) {
            var k = 0
            while (k < size) {
              values(k) = if (k == maxMarginIndex) 1.0 else 0.0
              k += 1
            }
          } else {
            val sum = {
              var temp = 0.0
              var k = 0
              while (k < numClasses) {
                values(k) = if (maxMargin > 0) {
                  math.exp(values(k) - maxMargin)
                } else {
                  math.exp(values(k))
                }
                temp += values(k)
                k += 1
              }
              temp
            }
            BLAS.scal(1 / sum, dv)
          }
          dv
        } else {
          var i = 0
          val size = dv.size
          while (i < size) {
            dv.values(i) = 1.0 / (1.0 + math.exp(-dv.values(i)))
            i += 1
          }
          dv
        }
      case sv: SparseVector =>
        throw new RuntimeException("Unexpected error in LogisticRegressionModel:" +
          " raw2probabilitiesInPlace encountered SparseVector")
    }
  }

  override protected def predictRaw(features: Vector): Vector = {
    if (isMultinomial) {
      margins(features)
    } else {
      val m = margin(features)
      Vectors.dense(-m, m)
    }
  }

  @Since("1.4.0")
  override def copy(extra: ParamMap): LogisticRegressionModel = {
    val newModel = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector,
      numClasses, isMultinomial), extra)
    newModel.setSummary(trainingSummary).setParent(parent)
  }

  override protected def raw2prediction(rawPrediction: Vector): Double = {
    if (isMultinomial) {
      super.raw2prediction(rawPrediction)
    } else {
      // 注意:我们应该使用getThreshold而不是$(threshold),因为getThreshold已经被覆盖了。
      val t = getThreshold
      val rawThreshold = if (t == 0.0) {
        Double.NegativeInfinity
      } else if (t == 1.0) {
        Double.PositiveInfinity
      } else {
        math.log(t / (1.0 - t))
      }
      if (rawPrediction(1) > rawThreshold) 1 else 0
    }
  }

  override protected def probability2prediction(probability: Vector): Double = {
    if (isMultinomial) {
      super.probability2prediction(probability)
    } else {
      // 注意:我们应该使用getThreshold而不是$(threshold),因为getThreshold已经被覆盖了。
      if (probability(1) > getThreshold) 1 else 0
    }
  }

  /**
   * 返回一个[[org.apache.spark.ml.util.MLWriter]]实例,用于保存此ML实例。
   *
   * 对于[[LogisticRegressionModel]],当前不会保存训练[[summary]]。
   * 将来可能会添加保存[[summary]]的选项。
   *
   * 目前也不会保存[[parent]]。
   */
  @Since("1.6.0")
  override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this)

  override def toString: String = {
    s"LogisticRegressionModel: " +
    s"uid = ${super.toString}, numClasses = $numClasses, numFeatures = $numFeatures"
  }
}

4. 官方链接

LogisticRegression源码

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

BigDataMLApplication

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值