An In-Depth Look at the Spark LR Implementation Source Code

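This article examines the LogisticRegression (LR) model in Spark. It covers what the model is, how to use it, and how its source code works, explaining the details of the LR implementation. Examples show how LR models are applied and trained with the mllib and ml libraries.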

Introduction to the LR model

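LR is short for LogisticRegression (logistic regression). Despite the name, it is a classification model, and at its core it is a linear model. It computes a linear score w·x + b and passes that score through the sigmoid function σ(z) = 1 / (1 + e^(-z)) to get a probability. Its decision boundary, w·x + b = 0, is therefore a hyperplane.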
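The following minimal plain-Scala sketch illustrates the "linear model" point for binary prediction. It does not use Spark, and LogisticSketch and its methods are hypothetical names. The prediction is a linear score, then the sigmoid, then a 0.5 threshold. That threshold is also the default used by mllib's LogisticRegressionModel.

```scala
// Hypothetical helper, not Spark code: binary logistic regression prediction.
object LogisticSketch {
  // sigmoid(z) = 1 / (1 + e^(-z)) maps any real score to (0, 1)
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  // Linear score w·x + b
  def margin(weights: Array[Double], intercept: Double, x: Array[Double]): Double = {
    require(weights.length == x.length, "dimension mismatch")
    weights.indices.map(i => weights(i) * x(i)).sum + intercept
  }

  // Probability of class 1
  def probability(weights: Array[Double], intercept: Double, x: Array[Double]): Double =
    sigmoid(margin(weights, intercept, x))

  // Class 1 if the probability exceeds the threshold, else class 0
  def predict(weights: Array[Double], intercept: Double, x: Array[Double],
              threshold: Double = 0.5): Double =
    if (probability(weights, intercept, x) > threshold) 1.0 else 0.0
}
```

Because the sigmoid is monotonic, probability > 0.5 is the same condition as w·x + b > 0. The classifier therefore only ever checks which side of a hyperplane a point falls on.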

Usage

import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithLBFGS}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils

// Load training data in LIBSVM format.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

// Split data into training (60%) and test (40%).
val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
val training = splits(0).cache()
val test = splits(1)

// Train the model with the LBFGS optimizer
val model = new LogisticRegressionWithLBFGS()
  .setNumClasses(10)
  .run(training)

// Compute raw scores on the test set.
val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
  val prediction = model.predict(features)
  (prediction, label)
}

// Get evaluation metrics.
val metrics = new MulticlassMetrics(predictionAndLabels)
val accuracy = metrics.accuracy
println(s"Accuracy = $accuracy")

// Save and load model
model.save(sc, "target/tmp/scalaLogisticRegressionWithLBFGSModel")
val sameModel = LogisticRegressionModel.load(sc,
  "target/tmp/scalaLogisticRegressionWithLBFGSModel")
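In the example above, metrics.accuracy is the fraction of test points whose predicted label equals the true label. The sketch below shows that computation on a plain local Seq rather than an RDD. It assumes unweighted samples, and AccuracySketch is a hypothetical name, not a Spark class.

```scala
// Hypothetical helper, not Spark code: accuracy over (prediction, label) pairs.
object AccuracySketch {
  def accuracy(predictionAndLabels: Seq[(Double, Double)]): Double = {
    require(predictionAndLabels.nonEmpty, "need at least one pair")
    // Count pairs where the predicted class matches the true label
    val correct = predictionAndLabels.count { case (prediction, label) => prediction == label }
    correct.toDouble / predictionAndLabels.size
  }
}
```

For example, three correct predictions out of four give an accuracy of 0.75.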

Source code analysis

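The code above instantiates LogisticRegressionWithLBFGS, the mllib LR learner that is trained with L-BFGS. Its source (link: LR source code) is: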

@Since("1.1.0")
class LogisticRegressionWithLBFGS
  extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable {

  // Scale features before optimization (enabled by default for this class)
  this.setFeatureScaling(true)

  @Since("1.1.0")
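  // The optimizer is LBFGS. BFGS is a quasi-Newton method; the "L" means limited-memory:
  // instead of storing a dense Hessian approximation, it keeps only the last few correction pairs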
  override val optimizer = new LBFGS(new LogisticGradient, new SquaredL2Updater)

  override protected val validators = List(multiLabelValidator)

  private def multiLabelValidator: RDD[LabeledPoint] => Boolean = { data =>
    if (numOfLinearPredictor > 1) {
      DataValidators.multiLabelValidator(numOfLinearPredictor + 1)(data)
    } else {
      DataValidators.binaryLabelValidator(data)
    }
  }


  // For multinomial logistic regression, set the number of possible output classes k.
  // The default is binary logistic regression.
  @Since("1.3.0")
  def setNumClasses(numClasses: Int): this.type = {
    require(numClasses > 1)
    // Number of linear predictors = numClasses - 1; binary classification needs only one
    numOfLinearPredictor = numClasses - 1
    if (numClasses > 2) {
      // Multiclass problem: switch the optimizer to the multinomial gradient calculator
      optimizer.setGradient(new LogisticGradient(numClasses))
    }
    this
  }

  override protected def createModel(weights: Vector, intercept: Double) = {
    if (numOfLinearPredictor == 1) {
      // Binary classification: a single linear predictor
      new LogisticRegressionModel(weights, intercept)
    } else {
      // Multiclass: the model holds numOfLinearPredictor (= numClasses - 1) predictors
      new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1)
    }
  }

  /**
   * Run Logistic Regression with the configured parameters on an input RDD
   * of LabeledPoint entries.
   *
   * If a known updater is used calls the ml implementation, to avoid
   * applying a regularization penalty to the intercept, otherwise
   * defaults to the mllib implementation. If more than two classes
   * or feature scaling is disabled, always uses mllib implementation.
   * If using ml implementation, uses ml code to generate initial weights.
   */

  override def run(input: RDD[LabeledPoint]): LogisticRegressionModel = {
    run(input, generateInitialWeights(input), userSuppliedWeights = false)
  }

  /**
   * Run Logistic Regression with the configured parameters on an input RDD
   * of LabeledPoint entries starting from the initial weights provided.
   *
   * If a known updater is used calls the ml implementation, to avoid
   * applying a regularization penalty to the intercept, otherwise
   * defaults to the mllib implementation. If more than two classes
   * or feature scaling is disabled, always uses mllib implementation.
   * Uses user provided weights.
   *
   * In the ml LogisticRegression implementation, the number of corrections
   * used in the LBFGS update can not be configured. So `optimizer.setNumCorrections()`
   * will have no effect if we fall into that route.
   */
  // Question
  override def run(input: RDD[LabeledPoint], initialWeights: Vector): LogisticRegressionModel = {
    run(input, initialWeights, userSuppliedWeights = true)
  }

  // ... (remaining members of the class are not shown)
}
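In the source, setNumClasses sets numOfLinearPredictor = numClasses - 1 and switches to LogisticGradient(numClasses). K - 1 predictors suffice because mllib's multinomial formulation treats class 0 as a pivot whose score is fixed at 0. The plain-Scala sketch below illustrates that idea. PivotMultinomialSketch is hypothetical. For readability it uses nested arrays with separate intercepts instead of Spark's single flattened weight vector.

```scala
// Hypothetical helper, not Spark code: K-class logistic regression with class 0 as pivot,
// using K - 1 linear predictors (predictors(k - 1) scores class k, for k = 1..K-1).
object PivotMultinomialSketch {
  private def dot(a: Array[Double], b: Array[Double]): Double =
    a.indices.map(i => a(i) * b(i)).sum

  def probabilities(predictors: Array[Array[Double]], intercepts: Array[Double],
                    x: Array[Double]): Array[Double] = {
    require(predictors.nonEmpty, "need at least one predictor (K >= 2)")
    require(predictors.length == intercepts.length, "one intercept per predictor")
    val margins = predictors.indices.map(k => dot(predictors(k), x) + intercepts(k))
    // Subtract the largest score (including the pivot's 0) so exp() cannot overflow
    val maxMargin = math.max(0.0, margins.max)
    val shifted = margins.map(m => math.exp(m - maxMargin))
    val pivot = math.exp(-maxMargin) // class 0 has score 0
    val denominator = pivot + shifted.sum
    (pivot +: shifted).map(_ / denominator).toArray
  }

  // Index of the most probable class
  def predict(predictors: Array[Array[Double]], intercepts: Array[Double], x: Array[Double]): Int = {
    val p = probabilities(predictors, intercepts, x)
    p.indices.maxBy(i => p(i))
  }
}
```

With two classes this reduces to the binary model: one predictor, P(y = 1) = e^m / (1 + e^m) = σ(m).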
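The class wires LogisticGradient and SquaredL2Updater into LBFGS. In the binary case, the optimizer minimizes roughly the average log loss plus an L2 penalty (λ/2)·||w||², using the gradient (1/n)·Σ(σ(w·x) - y)·x + λ·w. The sketch below computes that objective and gradient in plain Scala for labels in {0, 1}. It is an approximation of what those classes compute, not their code: it omits the intercept and Spark's feature scaling, and BinaryLogLossSketch is a hypothetical name.

```scala
// Hypothetical helper, not Spark code: L2-regularized binary log loss and its gradient.
object BinaryLogLossSketch {
  private def dot(a: Array[Double], b: Array[Double]): Double =
    a.indices.map(i => a(i) * b(i)).sum

  // Numerically stable log(1 + exp(z))
  private def log1pExp(z: Double): Double =
    if (z > 0) z + math.log1p(math.exp(-z)) else math.log1p(math.exp(z))

  /** Returns (loss, gradient) over (label, features) pairs with labels in {0, 1}. */
  def lossAndGradient(data: Seq[(Double, Array[Double])], w: Array[Double],
                      lambda: Double): (Double, Array[Double]) = {
    require(data.nonEmpty, "need at least one example")
    val n = data.size.toDouble
    val grad = new Array[Double](w.length)
    var loss = 0.0
    for ((y, x) <- data) {
      val margin = dot(w, x)
      // Per-example log loss: log(1 + exp(margin)) - y * margin
      loss += log1pExp(margin) - y * margin
      // Per-example gradient: (sigmoid(margin) - y) * x
      val multiplier = 1.0 / (1.0 + math.exp(-margin)) - y
      for (i <- x.indices) grad(i) += multiplier * x(i)
    }
    // Average the data term, then add the L2 penalty (lambda / 2) * ||w||^2 and its gradient lambda * w
    val totalLoss = loss / n + 0.5 * lambda * dot(w, w)
    val totalGrad = grad.indices.map(i => grad(i) / n + lambda * w(i)).toArray
    (totalLoss, totalGrad)
  }
}
```

A quasi-Newton optimizer such as L-BFGS only needs this (loss, gradient) pair at each iterate. That is why the mllib class can swap the gradient (binary vs. multinomial) without changing the optimizer.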