4.4. optimize
The optimizer we use is LBFGS; its optimize function is shown below.
override def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = {
  val (weights, _) = LBFGS.runLBFGS(
    data,
    gradient,          // LogisticGradient
    updater,           // SquaredL2Updater
    numCorrections,    // default 10
    convergenceTol,    // default 1E-6
    maxNumIterations,  // default 100
    regParam,          // 0.0
    initialWeights)
  weights
}
These default parameters are all encapsulated in MLlib's LBFGS class; the actual training happens in the runLBFGS function of the LBFGS companion object, whose core flow is sketched next.
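Condensed, runLBFGS essentially wires the CostFun described in the next subsection into Breeze's L-BFGS implementation. The sketch below assumes runLBFGS's parameters (data, gradient, updater, and so on) are in scope; input checks, logging, and loss-history bookkeeping are omitted.

import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{CachedDiffFunction, LBFGS => BreezeLBFGS}

// Sketch of the core of LBFGS.runLBFGS (not the full source).
val numExamples = data.count()
val costFun = new CostFun(data, gradient, updater, regParam, numExamples)
val lbfgs = new BreezeLBFGS[BDV[Double]](maxNumIterations, numCorrections, convergenceTol)

// Breeze returns a lazy iterator of optimizer States; advancing it one
// step evaluates CostFun.calculate once, i.e. one distributed pass over data.
val states = lbfgs.iterations(
  new CachedDiffFunction(costFun),
  initialWeights.asBreeze.toDenseVector)

// Drain the iterator to run until convergence (or maxNumIterations);
// the final State's `x` field holds the trained weights.
var state = states.next()
while (states.hasNext) { state = states.next() }
val weights = Vectors.fromBreeze(state.x)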
4.4.1. Data structures used in training
4.4.1.1. Loss function
First, the loss and gradient computations are wrapped in the CostFun class, so that they can be evaluated conveniently at each L-BFGS iteration:
/**
 * CostFun implements Breeze's DiffFunction[T], which returns the loss and gradient
 * at a particular point (weights). It's used in Breeze's convex optimization routines.
 */
private class CostFun(
    data: RDD[(Double, Vector)],
    gradient: Gradient,
    updater: Updater,
    regParam: Double,
    numExamples: Long) extends DiffFunction[BDV[Double]] {

  override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
    // Have a local copy to avoid the serialization of CostFun object which is not serializable.
    val w = Vectors.fromBreeze(weights)
    val n = w.size
    val bcW = data.context.broadcast(w)
    val localGradient = gradient

    val (gradientSum, lossSum) = data.treeAggregate((Vectors.zeros(n), 0.0))(
      // executor-side op: accumulates grad and loss over each partition
      // (see the treeAggregate sketch after this block)
      seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
        val l = localGradient.compute(
          features, label, bcW.value, grad)
        (grad, loss + l)
      },
      // driver-side op: combines the results returned by all partitions
      combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
        axpy(1.0, grad2, grad1)
        (grad1, loss1 + loss2)
      })

    /**
     * regVal is sum of weight squares if it's L2 updater;
     * for other updater, the same logic is followed.
     */
    // compute the loss
    val regVal = updater.compute(w, Vectors.zeros(n), 0, 1, regParam)._2
    val loss = lossSum / numExamples + regVal

    /**
     * It will return the gradient part of regularization using updater.
     *
     * Given the input parameters, the updater basically does the following,
     *
     * w' = w - thisIterStepSize * (gradient + regGradient(w))
     * Note that regGradient is function of w
     *
     * If we set gradient = 0, thisIterStepSize = 1, then
     *
     * regGradient(w) = w - w'
     *
     * TODO: We need to clean it up by separating the logic of regularization out
     * from updater to regularizer.
     */
    // The following gradientTotal is actually the regularization part of gradient.
    // Will add the gradientSum computed from the data with weights in the next step.
    // compute the gradient
    val gradientTotal = w.copy
    axpy(-1.0, updater.compute(w, Vectors.zeros(n), 1, 1, regParam)._1, gradientTotal)

    // gradientTotal = gradientSum / numExamples + gradientTotal
    axpy(1.0 / numExamples, gradientSum, gradientTotal)

    (loss, gradientTotal.asBreeze.asInstanceOf[BDV[Double]])
  }
}
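If treeAggregate is unfamiliar: it behaves like aggregate, but merges partition results in a multi-level tree rather than all at once on the driver. Below is a self-contained toy sketch of the same (accumulator, seqOp, combOp) pattern, using made-up data rather than MLlib code, and assuming a SparkContext sc is available (e.g. in spark-shell).

// Toy example: accumulate (sum of squares, element count), the same way
// CostFun accumulates (gradientSum, lossSum).
val nums = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0), numSlices = 2)

val (sumSq, count) = nums.treeAggregate((0.0, 0L))(
  // seqOp: runs per partition, folds one element into the local accumulator
  seqOp = (c, x) => (c._1 + x * x, c._2 + 1L),
  // combOp: merges two partition-level accumulators pairwise up a tree
  combOp = (c1, c2) => (c1._1 + c2._1, c1._2 + c2._2))

println(s"sumSq = $sumSq, count = $count") // sumSq = 30.0, count = 4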
4.4.1.2. State
The quantities carried across iterations are wrapped in a simple State class:
/**
* Tracks the information about the optimizer, including the current point, its value, gradient, and then any history.
* Also includes information for checking convergence