Spark MLlib Source Code Analysis: L-BFGS (Part 2)

This article dives into the implementation of the L-BFGS optimizer in Spark MLlib, analyzing the training data structures, the loss function, State, the approximate inverse Hessian, and the training process, including computing the descent direction, choosing the step size, updating the weights, and updating the Hessian approximation.

Related articles
Spark source code analysis: L-BFGS (Part 1)
Line search
Spark regularization
Spark MLlib source code analysis: OWLQN
Other source code analysis articles
Spark source code analysis: DecisionTree and GBDT
Spark source code analysis: Random Forest

4.4. optimize

Our optimizer is LBFGS; its optimize function is:

  override def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = {
    val (weights, _) = LBFGS.runLBFGS(
      data,
      gradient,   //LogisticGradient
      updater,    //SquaredL2Updater
      numCorrections,  //default 10
      convergenceTol,   //default 1E-6
      maxNumIterations,  //default 100
      regParam,          //0.0
      initialWeights)
    weights
  }

Its default parameters are all encapsulated in mllib's LBFGS class; the actual training takes place in the runLBFGS function of object LBFGS.
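
A condensed sketch of what runLBFGS does is shown below (paraphrased from the Spark source rather than quoted verbatim; logging and convergence warnings are omitted, and details vary across Spark versions). Here BreezeLBFGS and BDV are the aliases the Spark source uses for breeze.optimize.LBFGS and breeze.linalg.DenseVector, and CostFun is the loss/gradient wrapper discussed in 4.4.1.1 below.

  def runLBFGS(
      data: RDD[(Double, Vector)],
      gradient: Gradient,
      updater: Updater,
      numCorrections: Int,
      convergenceTol: Double,
      maxNumIterations: Int,
      regParam: Double,
      initialWeights: Vector): (Vector, Array[Double]) = {
    val lossHistory = mutable.ArrayBuilder.make[Double]
    val numExamples = data.count()

    // wrap the distributed loss/gradient evaluation for Breeze (see 4.4.1.1)
    val costFun = new CostFun(data, gradient, updater, regParam, numExamples)

    // Breeze's L-BFGS driver: max iterations, number of corrections m, convergence tolerance
    val lbfgs = new BreezeLBFGS[BDV[Double]](maxNumIterations, numCorrections, convergenceTol)

    // lazily iterate optimizer states until Breeze decides it has converged
    val states =
      lbfgs.iterations(new CachedDiffFunction(costFun), initialWeights.asBreeze.toDenseVector)

    var state = states.next()
    while (states.hasNext) {
      lossHistory += state.value
      state = states.next()
    }
    lossHistory += state.value

    // the final state holds the trained weights
    (Vectors.fromBreeze(state.x), lossHistory.result())
  }

The iterations call is lazy: each states.next() advances the optimizer by one L-BFGS step, and every evaluation of CostFun triggers a distributed pass over the data.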

4.4.1. Data structures used in training

4.4.1.1. Loss function

First, the loss and gradient computation is wrapped in the CostFun class, so that it can be evaluated conveniently during the L-BFGS iterations:

  /**
   * CostFun implements Breeze's DiffFunction[T], which returns the loss and gradient
   * at a particular point (weights). It's used in Breeze's convex optimization routines.
   */
  private class CostFun(
    data: RDD[(Double, Vector)],
    gradient: Gradient,
    updater: Updater,
    regParam: Double,
    numExamples: Long) extends DiffFunction[BDV[Double]] {

    override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
      // Have a local copy to avoid the serialization of CostFun object which is not serializable.
      val w = Vectors.fromBreeze(weights)
      val n = w.size
      val bcW = data.context.broadcast(w)
      val localGradient = gradient

      val (gradientSum, lossSum) = data.treeAggregate((Vectors.zeros(n), 0.0))(
          // executor-side op: compute grad and loss within each partition; see treeAggregate for details
          seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
            val l = localGradient.compute(
              features, label, bcW.value, grad)
            (grad, loss + l)
          },
          // driver-side op: merge the results returned from all partitions
          combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
            axpy(1.0, grad2, grad1)
            (grad1, loss1 + loss2)
          })

      /**
       * regVal is sum of weight squares if it's L2 updater;
       * for other updater, the same logic is followed.
       */
      // compute the loss
      val regVal = updater.compute(w, Vectors.zeros(n), 0, 1, regParam)._2

      val loss = lossSum / numExamples + regVal
      /**
       * It will return the gradient part of regularization using updater.
       *
       * Given the input parameters, the updater basically does the following,
       *
       * w' = w - thisIterStepSize * (gradient + regGradient(w))
       * Note that regGradient is function of w
       *
       * If we set gradient = 0, thisIterStepSize = 1, then
       *
       * regGradient(w) = w - w'
       *
       * TODO: We need to clean it up by separating the logic of regularization out
       *       from updater to regularizer.
       */
      // The following gradientTotal is actually the regularization part of gradient.
      // Will add the gradientSum computed from the data with weights in the next step.
      // compute the gradient
      val gradientTotal = w.copy
      axpy(-1.0, updater.compute(w, Vectors.zeros(n), 1, 1, regParam)._1, gradientTotal)

      // gradientTotal = gradientSum / numExamples + gradientTotal
      axpy(1.0 / numExamples, gradientSum, gradientTotal)

      (loss, gradientTotal.asBreeze.asInstanceOf[BDV[Double]])
    }
  }
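
To see why updater.compute(w, 0, 1, 1, regParam) recovers the regularization part of the gradient, here is a tiny standalone illustration with hypothetical values; it mirrors SquaredL2Updater's shrinkage step in plain Scala arithmetic instead of calling the MLlib API.

  // With gradient = 0, stepSize = 1 and iter = 1, SquaredL2Updater only applies shrinkage:
  //   w' = w * (1 - thisIterStepSize * regParam) = w * (1 - regParam)
  // so the "missing" piece w - w' = regParam * w, i.e. the gradient of 0.5 * regParam * ||w||^2.
  val regParam = 0.1                                         // hypothetical value
  val w        = Array(1.0, -2.0, 3.0)                       // hypothetical weights
  val wPrime   = w.map(_ * (1.0 - regParam))                 // what the updater returns as w'
  val regGrad  = w.zip(wPrime).map { case (wi, wpi) => wi - wpi }
  // regGrad ≈ Array(0.1, -0.2, 0.3), i.e. regParam * w (up to floating-point rounding)

The final axpy then adds gradientSum / numExamples on top of this regularization gradient, so gradientTotal ends up being the gradient of the full regularized loss.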

4.4.1.2. State

The parameters used during the iterations are wrapped up in a simple State structure:

  /**
   * Tracks the information about the optimizer, including the current point, its value, gradient, and then any history.
   * Also includes information for checking convergence.
   */
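
What follows in Breeze is the State case class itself. A rough sketch of its shape, with field names paraphrased from breeze.optimize.FirstOrderMinimizer (not verbatim; type parameters and extra bookkeeping fields differ across Breeze versions):

  case class State[+T, +History](
      x: T,                   // current point, i.e. the weights
      value: Double,          // objective value f(x) at the current point
      grad: T,                // gradient f'(x)
      adjustedValue: Double,  // f(x) plus any regularization handled inside the minimizer
      adjustedGradient: T,    // the correspondingly adjusted gradient
      iter: Int,              // current iteration number
      initialAdjVal: Double,  // adjusted value at the starting point, used in convergence checks
      history: History)       // optimizer-specific history
  // (extra bookkeeping fields such as line-search failure flags are omitted in this sketch)

For L-BFGS, history holds the recent position and gradient differences (the s and y vectors) from which the approximate inverse Hessian discussed later is built.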