相关文章
spark正则化
其他源码分析文章
spark mllib源码分析之OWLQN
spark源码分析之DecisionTree与GBDT
spark源码分析之随机森林(Random Forest)
1. 使用
spark给出的example中涉及到LBFGS有两个,分别是LBFGSExample.scala和LogisticRegressionWithLBFGSExample.scala,第一个是直接使用LBFGS直接训练,需要指定一系列优化参数,优点是比较灵活,可以自己控制的参数较多。后者使用了LogisticRegressionWithLBFGS,只能设置class的个数,其他参数都是固定的,其实就是将第一个中自己能控制的参数,都指定了默认值,适合刚开始时学习。
下面会以第二个为例,因为其中封装了第一个。
// Split data into training (60%) and test (40%).
val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
val training = splits(0).cache()
val test = splits(1)
// Run training algorithm to build the model
val model = new LogisticRegressionWithLBFGS()
.setNumClasses(10)
.run(training)
设置包含10分类,调节训练和验证数据集的比例
2. 逻辑回归与L-BFGS算法
这一节简单介绍逻辑回归和L-BFGS算法原理,以便于代码实现相互对照,但是不会做严格的数学推理。
2.1. softmax
这里贴两页代码注释中给出的PPT(https://www.slideshare.net/dbtsai/2014-0620-mlor-36132297)
简单解释下PPT中的内容,第二页似然函数最后的形式中,α只有当样本的label是0的时候取1,第一项为0,其他的label取0,也就是第一项是有值的。第一项虽然在累加号里面,但是我们注意权重w的下标yk,这意味着当累加索引k从0到N循环时,只有其与样本label y(label也是从1到N)相等时,x*w才会被计算,在代码中这项是marginY。
对于第二项,loss和gradient计算中含有指数运算,这部分累加在代码中是margins变量,如果数据中存在异常点,对应到指数如果超过了709.78,就会溢出,导致训练失败,这里做了一点trick,在涉及到指数计算的地方,会先判断计算出的指数部分是否大于0,如果大于0,直接在后面加指数部分,相当于在log里面再除掉,这样就不会溢出了。