spark mllib源码分析之L-BFGS(一)

相关文章
spark正则化
其他源码分析文章
spark mllib源码分析之OWLQN
spark源码分析之DecisionTree与GBDT
spark源码分析之随机森林(Random Forest)

1. 使用

spark给出的example中涉及到LBFGS有两个,分别是LBFGSExample.scala和LogisticRegressionWithLBFGSExample.scala,第一个是直接使用LBFGS直接训练,需要指定一系列优化参数,优点是比较灵活,可以自己控制的参数较多。后者使用了LogisticRegressionWithLBFGS,只能设置class的个数,其他参数都是固定的,其实就是将第一个中自己能控制的参数,都指定了默认值,适合刚开始时学习。
下面会以第二个为例,因为其中封装了第一个。

// Split data into training (60%) and test (40%).
val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
val training = splits(0).cache()
val test = splits(1)
// Run training algorithm to build the model
val model = new LogisticRegressionWithLBFGS()
      .setNumClasses(10)
      .run(training)

设置包含10分类,调节训练和验证数据集的比例

2. 逻辑回归与L-BFGS算法

这一节简单介绍逻辑回归和L-BFGS算法原理,以便于代码实现相互对照,但是不会做严格的数学推理。

2.1. softmax

这里贴两页代码注释中给出的PPT(https://www.slideshare.net/dbtsai/2014-0620-mlor-36132297
这里写图片描述
这里写图片描述
简单解释下PPT中的内容,第二页似然函数最后的形式中,α只有当样本的label是0的时候取1,第一项为0,其他的label取0,也就是第一项是有值的。第一项虽然在累加号里面,但是我们注意权重w的下标yk,这意味着当累加索引k从0到N循环时,只有其与样本label y(label也是从1到N)相等时,x*w才会被计算,在代码中这项是marginY。
对于第二项,loss和gradient计算中含有指数运算,这部分累加在代码中是margins变量,如果数据中存在异常点,对应到指数如果超过了709.78,就会溢出,导致训练失败,这里做了一点trick,在涉及到指数计算的地方,会先判断计算出的指数部分是否大于0,如果大于0,直接在后面加指数部分,相当于在log里面再除掉,这样就不会溢出了。

log(1+i=1Nexp(xwi))=log(exp(m)(1+i=1Nexp(x
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值