How to use RDD.treeAggregate

Original source: http://stackoverflow.com/questions/29860635/how-to-interpret-rdd-treeaggregate

In the Spark source, the runMiniBatchSGD function of GradientDescent contains the following code:

val (gradientSum, lossSum, miniBatchSize) = data.sample(false, miniBatchFraction, 42 + i) // sample a mini-batch of the data
        .treeAggregate((BDV.zeros[Double](n), 0.0, 0L))(
          seqOp = (c, v) => {
            // c: (grad, loss, count), v: (label, features)
            val l = gradient.compute(v._2, v._1, bcWeights.value, Vectors.fromBreeze(c._1))
            (c._1, c._2 + l, c._3 + 1)
          },
          combOp = (c1, c2) => {
            // c: (grad, loss, count)
            (c1._1 += c2._1, c1._2 + c2._2, c1._3 + c2._3)
          })

Someone on Stack Overflow gave the following explanation:
treeAggregate is a specialized implementation of aggregate that iteratively applies the combine function to a subset of partitions.

This is done to avoid returning all partial results to the driver, where a single-pass reduce would take place, as the classic aggregate does.

For all practical purposes, treeAggregate follows the same principle as aggregate, explained in this answer (Explain the aggregate functionality in Python), with the exception that it takes an extra parameter to indicate the depth of the partial aggregation level.
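
To see the shape of the call outside of GradientDescent, here is a minimal, self-contained sketch (not from the Spark source; the object name and the numbers are made up for illustration). It computes a (sum, count) pair with treeAggregate, so the per-partition partial results are merged in a small tree rather than all at once on the driver:

import org.apache.spark.sql.SparkSession

object TreeAggregateDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[4]").appName("treeAggregateDemo").getOrCreate()

    // 100 doubles spread across 8 partitions
    val data = spark.sparkContext.parallelize((1 to 100).map(_.toDouble), numSlices = 8)

    val (sum, count) = data.treeAggregate((0.0, 0L))(
      seqOp = (acc, x) => (acc._1 + x, acc._2 + 1L),   // fold one element into a partition-local accumulator
      combOp = (a, b) => (a._1 + b._1, a._2 + b._2),   // merge two partial accumulators
      depth = 2                                        // depth of the aggregation tree (also the default)
    )

    println(s"sum = $sum, count = $count, mean = ${sum / count}")
    spark.stop()
  }
}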

Let me try to explain what’s going on here specifically:

For aggregate, we need a zero, a combiner function and a reduce function. aggregate uses currying to specify the zero value independently of the combine and reduce functions.
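
As a reminder of that curried shape, here is a hypothetical spark-shell-style snippet (it assumes an existing SparkContext named sc; the values are invented). The zero value sits in its own parameter list, separate from the two functions:

// Hypothetical snippet, assuming a SparkContext named sc is already available (e.g. in spark-shell)
val nums = sc.parallelize(Seq(1, 2, 3, 4), numSlices = 2)
val (sum, count) = nums.aggregate((0, 0))(
  (acc, x) => (acc._1 + x, acc._2 + 1),    // seqOp: fold an element into the partition-local accumulator
  (a, b)   => (a._1 + b._1, a._2 + b._2)   // combOp: merge accumulators coming from different partitions
)
// sum == 10, count == 4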

We can then dissect the above call like this. Hopefully that helps with understanding:

val Zero: (BDV[Double], Double, Long) = (BDV.zeros[Double](n), 0.0, 0L)

val combinerFunction: ((BDV[Double], Double, Long), (??, ??)) => (BDV[Double], Double, Long) = (c, v) => {
        // c: (grad, loss, count), v: (label, features)
        val l = gradient.compute(v._2, v._1, bcWeights.value, Vectors.fromBreeze(c._1))
        (c._1, c._2 + l, c._3 + 1)
      }

val reducerFunction: ((BDV[Double], Double, Long), (BDV[Double], Double, Long)) => (BDV[Double], Double, Long) = (c1, c2) => {
        // c: (grad, loss, count)
        (c1._1 += c2._1, c1._2 + c2._2, c1._3 + c2._3)
      }

Then we can rewrite the call to treeAggregate in a more digestible form:

val (gradientSum, lossSum, miniBatchSize) = treeAggregate(Zero)(combinerFunction, reducerFunction)

This form will ‘extract’ the resulting tuple into the named values gradientSum, lossSum, miniBatchSize for further usage.

Note that treeAggregate takes an additional parameter, depth, which is declared with a default value of depth = 2; since it is not provided in this particular call, the default is used.
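
For illustration only (this exact line is not in the Spark source), the same call with the depth argument written out explicitly would look like this, where sampledData stands for the data.sample(false, miniBatchFraction, 42 + i) result from the original snippet:

// Hypothetical rewrite; sampledData is a stand-in for the sampled RDD above
val (gradientSum, lossSum, miniBatchSize) =
  sampledData.treeAggregate(Zero)(combinerFunction, reducerFunction, depth = 2)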
